|
|
|
|
|
""" |
|
|
Test Router Agent for GAIA Agent System.

Tests question classification and agent selection logic.
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Put the directory containing this file at the front of sys.path before the
# `agents.*` / `models.*` imports that follow run.
# NOTE(review): presumably those packages live next to this file — confirm.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
from agents.state import GAIAAgentState, QuestionType, AgentRole |
|
|
from agents.router import RouterAgent |
|
|
from models.qwen_client import QwenClient |
|
|
|
|
|
def test_router_agent():
    """Run the router against a fixed set of sample questions and report results.

    A ``QwenClient`` and a ``RouterAgent`` are constructed, then each sample
    question is placed on a fresh ``GAIAAgentState`` and handed to
    ``router.route_question``.  A case passes when the returned
    ``question_type`` is one of the accepted types for that case and at least
    one of the expected agents appears in ``selected_agents``.  Any exception
    raised while routing or reporting a case is printed and the case is
    counted as failed.

    Returns:
        bool: ``True`` when at least 80% of the cases pass; ``False`` when
        fewer pass, or when the client/router cannot be constructed.
    """
    # NOTE(review): the status glyphs below were reconstructed from
    # mis-decoded (mojibake) text; confirm them against the intended output.
    ok_mark = "✅"
    fail_mark = "❌"
    pass_threshold = 80  # minimum pass rate (percent) for overall success

    print("🧭 GAIA Router Agent Test")
    print("=" * 40)

    # Nothing can be tested without a router, so bail out early on any
    # construction failure instead of letting the exception propagate.
    try:
        llm_client = QwenClient()
        router = RouterAgent(llm_client)
    except Exception as e:
        print(f"{fail_mark} Failed to initialize router: {e}")
        return False

    # "expected_type": any listed type is accepted.
    # "expected_agents": at least one listed agent must be selected.
    # "has_file": attach a dummy file reference to the state before routing.
    test_cases = [
        {
            "question": "What is the capital of France?",
            "expected_type": [
                QuestionType.WIKIPEDIA,
                QuestionType.WEB_RESEARCH,
                QuestionType.UNKNOWN,
            ],
            "expected_agents": [AgentRole.WEB_RESEARCHER],
        },
        {
            "question": "Calculate 25% of 400 and add 50",
            "expected_type": [QuestionType.MATHEMATICAL],
            "expected_agents": [AgentRole.REASONING_AGENT],
        },
        {
            "question": "What information can you extract from this CSV file?",
            "expected_type": [QuestionType.FILE_PROCESSING],
            "expected_agents": [AgentRole.FILE_PROCESSOR],
            "has_file": True,
        },
        {
            "question": "Search for recent news about artificial intelligence",
            "expected_type": [QuestionType.WEB_RESEARCH],
            "expected_agents": [AgentRole.WEB_RESEARCHER],
        },
        {
            "question": "What does this Python code do and how can it be improved?",
            "expected_type": [
                QuestionType.CODE_EXECUTION,
                QuestionType.FILE_PROCESSING,
            ],
            "expected_agents": [
                AgentRole.FILE_PROCESSOR,
                AgentRole.CODE_EXECUTOR,
            ],
            "has_file": True,
        },
    ]

    results = []

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test {i}: {test_case['question'][:50]}... ---")

        state = GAIAAgentState()
        state.question = test_case["question"]
        if test_case.get("has_file"):
            # NOTE(review): every file-backed case attaches the same CSV
            # fixture, including the Python-code question — confirm intended.
            state.file_name = "test_file.csv"
            state.file_path = "/tmp/test_file.csv"

        # Exactly one verdict is recorded per case.  It only becomes True once
        # routing, checking and reporting have all completed, so an exception
        # part-way through can never leave two entries for the same case.
        case_passed = False
        try:
            result_state = router.route_question(state)

            type_correct = result_state.question_type in test_case["expected_type"]
            agents_correct = any(
                agent in result_state.selected_agents
                for agent in test_case["expected_agents"]
            )
            success = type_correct and agents_correct

            print(f" Question Type: {result_state.question_type.value} ({ok_mark if type_correct else fail_mark})")
            print(f" Selected Agents: {[a.value for a in result_state.selected_agents]} ({ok_mark if agents_correct else fail_mark})")
            print(f" Complexity: {result_state.complexity_assessment}")
            print(f" Overall: {'✅ PASS' if success else '❌ FAIL'}")

            case_passed = success
        except Exception as e:
            # Deliberately broad: one misbehaving case must not abort the run.
            print(f" {fail_mark} FAIL: {e}")

        results.append(case_passed)

    passed = sum(results)
    total = len(results)
    pass_rate = (passed / total) * 100

    print("\n" + "=" * 40)
    print(f"🎯 ROUTER RESULTS: {passed}/{total} ({pass_rate:.1f}%)")

    if pass_rate >= pass_threshold:
        print("🎉 Router working correctly!")
        return True

    print("⚠️ Router needs improvement")
    return False
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: report the outcome to the shell through the process
    # exit status (0 when the router test succeeded, 1 otherwise).
    exit_status = 0 if test_router_agent() else 1
    sys.exit(exit_status)