|
|
|
|
|
""" |
|
|
Unit 4 API Client for GAIA Benchmark Questions |
|
|
Handles question fetching, file downloads, and answer submission |
|
|
""" |
|
|
|
|
|
import os |
|
|
import requests |
|
|
import logging |
|
|
from typing import Dict, Any, List, Optional, Union |
|
|
from dataclasses import dataclass |
|
|
import json |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Module-level logger named after this module; every function below logs through it.
logger = logging.getLogger(__name__)

# Importing this module configures the root logger at INFO level
# (basicConfig is a no-op when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
@dataclass
class GAIAQuestion:
    """One GAIA benchmark question as handled by the API client.

    The first three fields are required and positional; the remaining
    fields are optional and default to ``None``.
    """

    # Identifier of the task; also used to build the file-download endpoint.
    task_id: str
    # Question text.
    question: str
    # Level value reported for the question.
    level: int
    # Value of the payload's 'final_answer' key, when present.
    final_answer: Optional[str] = None
    # Name of the file associated with the question, when present.
    file_name: Optional[str] = None
    # Local path of the downloaded file; filled in after a successful download.
    file_path: Optional[str] = None
    # The raw payload dict the question was built from.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
@dataclass
class SubmissionResult:
    """Outcome of one answer submission.

    ``task_id``, ``submitted_answer`` and ``success`` are required; the
    remaining fields are optional and default to ``None``.
    """

    # Task the answer was submitted for.
    task_id: str
    # The answer that was handed to the submit call.
    submitted_answer: str
    # True when the submit request completed, False when it raised.
    success: bool
    # Value of the response's 'score' key, when present.
    score: Optional[float] = None
    # Value of the response's 'feedback' key, when present.
    feedback: Optional[str] = None
    # String form of the exception when the submission failed.
    error: Optional[str] = None
|
|
|
|
|
class Unit4APIClient:
    """Client for Unit 4 API to fetch GAIA questions and submit answers.

    All HTTP traffic goes through one ``requests.Session`` and is throttled
    so that consecutive requests are at least ``rate_limit_delay`` seconds
    apart. The public fetch/download/submit methods never raise on request
    failure: they log the error and return an empty/``None``/failed result.

    Attributes:
        base_url: API root without a trailing slash.
        session: Shared ``requests.Session`` carrying the default headers.
        downloads_dir: Directory downloaded files are written to.
        requests_made: Number of requests issued through this client.
        last_request_time: ``time.time()`` stamp of the latest request.
        rate_limit_delay: Minimum spacing between requests, in seconds.
        request_timeout: Timeout (seconds) applied to requests that do not
            pass their own ``timeout``.
    """

    # Class-level default so the attribute always resolves, even on an
    # instance created without running __init__.
    request_timeout: float = 30.0

    def __init__(
        self,
        base_url: str = "https://agents-course-unit4-scoring.hf.space",
        timeout: Optional[float] = None,
        downloads_dir: Union[str, Path] = "downloads",
    ):
        """Initialize Unit 4 API client.

        Args:
            base_url: API root; a trailing slash is removed.
            timeout: Per-request timeout in seconds. ``None`` keeps the
                class default (``request_timeout``).
            downloads_dir: Directory for downloaded files; created if missing.
        """
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'GAIA-Agent-System/1.0',
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        })

        if timeout is not None:
            self.request_timeout = timeout

        self.downloads_dir = Path(downloads_dir)
        self.downloads_dir.mkdir(parents=True, exist_ok=True)

        # Rate-limiting bookkeeping.
        self.requests_made = 0
        self.last_request_time = 0
        self.rate_limit_delay = 1.0

    @staticmethod
    def _safe_filename(candidate: Optional[str], fallback: str) -> str:
        """Reduce *candidate* to a bare file name, or return *fallback*.

        File names come from the caller or from a response header, so any
        directory components (``../``, absolute paths, backslash-separated
        paths) are dropped to keep writes inside the downloads directory.
        """
        # Normalise Windows separators so Path.name strips them on POSIX too.
        base = Path(str(candidate or "").replace("\\", "/")).name.strip()
        # Path("..").name is "..", so the dot entries need an explicit check.
        if base in ("", ".", ".."):
            return fallback
        return base

    @staticmethod
    def _filename_from_content_disposition(header: str) -> Optional[str]:
        """Extract the ``filename=`` value from a Content-Disposition header.

        Returns ``None`` when the header carries no ``filename=`` parameter or
        the value is empty. Parameters following the file name (``; size=...``)
        are discarded. A quoted name that itself contains ``;`` is truncated
        at that character.
        """
        if not header or 'filename=' not in header:
            return None
        value = header.split('filename=', 1)[1].split(';', 1)[0]
        value = value.strip().strip('"\'')
        return value or None

    @staticmethod
    def _parse_question(q_data: Dict[str, Any]) -> "GAIAQuestion":
        """Build a GAIAQuestion from one payload dict; the dict is kept as metadata."""
        return GAIAQuestion(
            task_id=q_data.get('task_id', ''),
            question=q_data.get('question', ''),
            level=q_data.get('level', 1),
            final_answer=q_data.get('final_answer'),
            file_name=q_data.get('file_name'),
            metadata=q_data
        )

    def _rate_limit(self):
        """Implement basic rate limiting.

        Sleeps for the remainder of ``rate_limit_delay`` when the previous
        request was more recent than that, then records the new request.
        """
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.rate_limit_delay:
            sleep_time = self.rate_limit_delay - time_since_last
            logger.debug(f"Rate limiting: sleeping {sleep_time:.2f}s")
            time.sleep(sleep_time)

        self.last_request_time = time.time()
        self.requests_made += 1

    def _make_request(self, method: str, endpoint: str, **kwargs) -> "requests.Response":
        """Make HTTP request with rate limiting and error handling.

        Args:
            method: HTTP method name.
            endpoint: Path appended to ``base_url``.
            **kwargs: Passed through to ``Session.request``. ``timeout``
                defaults to ``request_timeout`` when not supplied.

        Returns:
            The response, already checked with ``raise_for_status``.

        Raises:
            requests.exceptions.RequestException: Logged and re-raised on any
                transport failure or non-success status code.
        """
        self._rate_limit()

        # Without a timeout a stalled server would block the caller forever.
        kwargs.setdefault('timeout', self.request_timeout)
        url = f"{self.base_url}{endpoint}"

        try:
            logger.debug(f"Making {method} request to {url}")
            response = self.session.request(method, url, **kwargs)
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError:
                # A streamed error response would otherwise hold its connection.
                response.close()
                raise
            return response

        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {e}")
            raise

    def get_questions(self, level: Optional[int] = None, limit: Optional[int] = None) -> List["GAIAQuestion"]:
        """Fetch GAIA questions from the API.

        Args:
            level: Sent as the ``level`` query parameter when not ``None``.
            limit: Sent as the ``limit`` query parameter when not ``None``.

        Returns:
            Parsed questions, or an empty list when the request or parsing
            fails. Payload entries that are not dicts are skipped.
        """
        params: Dict[str, Any] = {}
        if level is not None:
            params['level'] = level
        if limit is not None:
            params['limit'] = limit

        try:
            response = self._make_request('GET', "/questions", params=params)
            data = response.json()

            # Accept a bare list, a {'questions': [...]} wrapper, or a single object.
            if isinstance(data, list):
                question_list = data
            elif isinstance(data, dict) and 'questions' in data:
                question_list = data['questions']
            else:
                question_list = [data]

            questions = []
            for q_data in question_list:
                if not isinstance(q_data, dict):
                    # One malformed entry should not discard the whole batch.
                    logger.warning(f"Skipping malformed question entry: {q_data!r}")
                    continue
                questions.append(self._parse_question(q_data))

            logger.info(f"✅ Fetched {len(questions)} questions from API")
            return questions

        except Exception as e:
            logger.error(f"❌ Failed to fetch questions: {e}")
            return []

    def get_random_question(self, level: Optional[int] = None) -> Optional["GAIAQuestion"]:
        """Fetch a random question from the API.

        Args:
            level: Sent as the ``level`` query parameter when not ``None``.

        Returns:
            The parsed question, or ``None`` when the request or parsing fails.
        """
        params: Dict[str, Any] = {}
        if level is not None:
            params['level'] = level

        try:
            response = self._make_request('GET', "/random-question", params=params)
            question = self._parse_question(response.json())

            logger.info(f"✅ Fetched random question: {question.task_id}")
            return question

        except Exception as e:
            logger.error(f"❌ Failed to fetch random question: {e}")
            return None

    def download_file(self, task_id: str, file_name: Optional[str] = None) -> Optional[str]:
        """Download file associated with a question.

        The target name is ``file_name`` when given, otherwise the
        Content-Disposition file name, otherwise ``"<task_id>_file"``. Only
        the base name is used, so the file always lands directly inside
        ``downloads_dir``.

        Args:
            task_id: Task whose file is requested; must be non-empty.
            file_name: Preferred name for the saved file.

        Returns:
            Path of the saved file as a string, or ``None`` on failure.
        """
        if not task_id:
            logger.error("Task ID required for file download")
            return None

        endpoint = f"/files/{task_id}"

        try:
            # Leaving the with-block closes the streamed response and
            # releases its pooled connection.
            with self._make_request('GET', endpoint, stream=True) as response:
                requested = file_name or self._filename_from_content_disposition(
                    response.headers.get('content-disposition', '')
                )
                fallback = self._safe_filename(f"{task_id}_file", "downloaded_file")
                file_path = self.downloads_dir / self._safe_filename(requested, fallback)

                with open(file_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            logger.info(f"✅ Downloaded file: {file_path}")
            return str(file_path)

        except Exception as e:
            logger.error(f"❌ Failed to download file for {task_id}: {e}")
            return None

    def submit_answer(self, task_id: str, answer: str) -> "SubmissionResult":
        """Submit answer for evaluation.

        The answer is sent as ``str(answer).strip()``; the returned result
        records the answer exactly as it was passed in.

        Returns:
            A result with ``success=True`` plus the response's ``score`` and
            ``feedback`` values, or ``success=False`` with ``error`` set when
            the request fails.
        """
        payload = {
            "task_id": task_id,
            "answer": str(answer).strip()
        }

        try:
            response = self._make_request('POST', "/submit", json=payload)
            data = response.json()

            result = SubmissionResult(
                task_id=task_id,
                submitted_answer=answer,
                success=True,
                score=data.get('score'),
                feedback=data.get('feedback'),
            )

            logger.info(f"✅ Submitted answer for {task_id}")
            if result.score is not None:
                logger.info(f" Score: {result.score}")
            if result.feedback:
                logger.info(f" Feedback: {result.feedback}")

            return result

        except Exception as e:
            logger.error(f"❌ Failed to submit answer for {task_id}: {e}")

            return SubmissionResult(
                task_id=task_id,
                submitted_answer=answer,
                success=False,
                error=str(e)
            )

    def validate_answer_format(self, answer: str, question: "GAIAQuestion") -> bool:
        """Validate answer format before submission.

        The answer is coerced with ``str()`` first, matching what
        ``submit_answer`` sends, so non-string answers are accepted.

        Returns:
            ``False`` for a ``None``, empty or whitespace-only answer,
            otherwise ``True``. Answers longer than 1000 characters only
            produce a warning.
        """
        text = "" if answer is None else str(answer)

        if not text.strip():
            logger.warning("Empty answer provided")
            return False

        if len(text) > 1000:
            logger.warning("Answer is very long (>1000 chars)")

        logger.debug(f"Answer validation passed for {question.task_id}")
        return True

    def get_api_status(self) -> Dict[str, Any]:
        """Check API status and endpoints.

        Issues a GET with a 5-second timeout to ``/questions`` and
        ``/random-question``.

        Returns:
            A dict with ``base_url``, ``requests_made`` (captured before the
            probes run) and ``endpoints_tested``, which maps each endpoint to
            ``{"status_code", "success": True}`` or
            ``{"success": False, "error"}``.
        """
        status = {
            "base_url": self.base_url,
            "requests_made": self.requests_made,
            "endpoints_tested": {}
        }

        test_endpoints = [
            ("/questions", "GET"),
            ("/random-question", "GET"),
        ]

        for endpoint, method in test_endpoints:
            try:
                response = self._make_request(method, endpoint, timeout=5)
                status["endpoints_tested"][endpoint] = {
                    "status_code": response.status_code,
                    "success": True
                }
            except Exception as e:
                status["endpoints_tested"][endpoint] = {
                    "success": False,
                    "error": str(e)
                }

        return status

    def process_question_with_files(self, question: "GAIAQuestion") -> "GAIAQuestion":
        """Process question and download associated files if needed.

        When the question has both ``file_name`` and ``task_id``, the file is
        downloaded and ``question.file_path`` is set on success. The same
        question object is returned in every case.
        """
        if question.file_name and question.task_id:
            logger.info(f"Downloading file for question {question.task_id}")
            file_path = self.download_file(question.task_id, question.file_name)

            if file_path:
                question.file_path = file_path
                logger.info(f"✅ File ready: {file_path}")
            else:
                logger.warning(f"❌ Failed to download file for {question.task_id}")

        return question
|
|
|
|
|
|
|
|
def test_api_connection():
    """Probe the API with a fresh client and log a pass/fail line per endpoint.

    Returns:
        The status dict produced by ``Unit4APIClient.get_api_status``.
    """
    logger.info("🧪 Testing Unit 4 API connection...")

    status = Unit4APIClient().get_api_status()

    logger.info("📊 API Status:")
    for endpoint, result in status["endpoints_tested"].items():
        if result["success"]:
            status_str = "✅ PASS"
        else:
            status_str = "❌ FAIL"
        logger.info(f" {endpoint:20}: {status_str}")

        # Only failed endpoints get an extra line with the error detail.
        if result["success"]:
            continue
        logger.info(f" Error: {result.get('error', 'Unknown')}")

    return status
|
|
|
|
|
def test_question_fetching():
    """Fetch one random question with a fresh client and log its details.

    When the question names a file, the file is downloaded through
    ``process_question_with_files``.

    Returns:
        The fetched question, or ``None`` when fetching failed.
    """
    logger.info("🧪 Testing question fetching...")

    client = Unit4APIClient()
    question = client.get_random_question()

    # Guard clause: nothing more to report when the fetch failed.
    if not question:
        logger.error("❌ Failed to fetch random question")
        return None

    logger.info(f"✅ Random question fetched: {question.task_id}")
    logger.info(f" Level: {question.level}")
    logger.info(f" Question: {question.question[:100]}...")
    logger.info(f" Has file: {question.file_name is not None}")

    if question.file_name:
        question = client.process_question_with_files(question)

    return question
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the connectivity check, then the question fetch.
    for _check in (test_api_connection, test_question_fetching):
        _check()