# GitGud AI Service — FastAPI app (Hugging Face Spaces deployment)
# NOTE(review): "Spaces: / Paused / Paused" page-scrape residue removed;
# it was not part of the program source.
| import os | |
| import re | |
| import logging | |
| import traceback | |
| import time | |
| from typing import List, Optional, Dict | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
| # Load environment variables | |
| load_dotenv() | |
| from app.predictor import classifier, guide_generator, reviewer | |
| # Note: AIReviewerService from the first version is typically | |
| # the underlying service for the 'reviewer' object in the second. | |
# 1. Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Initialize FastAPI
# NOTE(review): no route decorators are visible on the handlers below — they
# appear to have been lost in extraction, or routes are registered elsewhere;
# confirm against the deployed app before relying on these as endpoints.
app = FastAPI(title="GitGud AI Service")
main = app  # Alias for compatibility

# Global embedding cache, shared by classify/analyze (writers) and
# semantic_search (reader). Process-local: lost on restart, not shared
# across workers.
# Structure: { "repo_name": { "file_path": [embedding_vector] } }
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
| # 3. Data Models | |
class FileRequest(BaseModel):
    """Payload for single-file classify/analyze endpoints."""

    fileName: str  # path or name of the file being analyzed
    content: Optional[str] = None  # raw source text; may be omitted
    repoName: Optional[str] = None  # when set, the embedding is cached per repo
class BatchReviewRequest(BaseModel):
    """Payload for batch review and dashboard-stats endpoints."""

    files: List[FileRequest]  # files to review in one pass
class GuideRequest(BaseModel):
    """Payload for repository guide generation."""

    repoName: str  # repository to document
    filePaths: List[str]  # file paths to include in the guide
class SearchRequest(BaseModel):
    """Payload for semantic search over file embeddings."""

    query: str  # natural-language search query
    embeddings: Optional[Dict[str, List[float]]] = None  # Path -> Embedding
    repoName: Optional[str] = None  # fall back to the server-side cache for this repo
class ChatRequest(BaseModel):
    """Payload for RAG-based chat."""

    query: str  # user question
    context: List[Dict[str, str]]  # List of { "fileName": str, "content": str }
    repoName: str  # repository the question is about
| # 4. Endpoints | |
def health_check():
    """Report service liveness, model/device info, and which repos are cached.

    Returns a plain dict; FastAPI serializes it to JSON.
    """
    # Not every classifier backend exposes `.device` — default to "cpu".
    device = getattr(classifier, "device", "cpu")
    return {
        "status": "online",
        "model": "microsoft/codebert-base",
        "device": device,
        "cached_repos": [*REPO_CACHE],
    }
def get_usage():
    """Expose the shared LLM engine's usage statistics.

    The engine is imported lazily so the model loader is only touched
    when this endpoint is actually hit.
    """
    from app.core.model_loader import llm_engine

    stats = llm_engine.get_usage_stats()
    return stats
async def classify_file(request: FileRequest):
    """Classify a file into an architectural layer and cache its embedding.

    Side effect: when ``repoName`` is provided, the computed embedding is
    stored in the module-level REPO_CACHE for later semantic search.

    Raises:
        HTTPException(500): any failure from the underlying classifier.
    """
    try:
        result = classifier.predict(request.fileName, request.content)

        # setdefault replaces the duplicated "check then insert" pattern:
        # one lookup instead of two, same observable behavior.
        if request.repoName:
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]

        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "confidence": result["confidence"],
            "embedding": result["embedding"],
        }
    except Exception as e:
        # Lazy %-args: message is only built if the log record is emitted.
        logger.error("Classify failed: %s", e)
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def review_batch_code(request: BatchReviewRequest):
    """Run the reviewer over a batch of files and aggregate summary metrics.

    Raises:
        HTTPException(500): if the underlying review service fails.
    """
    try:
        reviews = reviewer.service.review_batch_code(request.files)

        file_count = len(reviews)
        vuln_count = 0
        maint_total = 0
        for review in reviews:
            vuln_count += len(review.get("vulnerabilities", []))
            # Missing metrics default to a maintainability score of 8.
            maint_total += review.get("metrics", {}).get("maintainability", 8)

        # max() keeps the divisor >= 1 so an empty batch averages to 0.
        mean_maint = maint_total / max(file_count, 1)

        return {
            "totalFiles": file_count,
            "totalVulnerabilities": vuln_count,
            "averageMaintainability": round(mean_maint, 1),
            "results": reviews,
        }
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def get_dashboard_stats(request: BatchReviewRequest):
    """Aggregated stats for frontend dashboards including health and API sniffing.

    Returns a repo health score, security issue count, a maintainability
    "performance ratio", and up to 10 API paths sniffed from file contents.

    Raises:
        HTTPException(500): if review or aggregation fails.
    """
    try:
        raw_reviews = reviewer.service.review_batch_code(request.files)

        # 1. Security Count
        total_vulns = sum(len(r.get("vulnerabilities", [])) for r in raw_reviews)

        # 2. Performance/Maintainability Ratio (scores default to 8, scaled x10)
        scores = [r.get("metrics", {}).get("maintainability", 8) for r in raw_reviews]
        avg_maintainability = (sum(scores) / len(scores)) * 10 if scores else 0

        # 3. API Sniffing (Regex): match get/post/put/delete/patch('/path') calls.
        found_apis = set()
        for f in request.files:
            if f.content:
                matches = re.findall(
                    r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]',
                    f.content.lower(),
                )
                for match in matches:
                    found_apis.add(f"/{match}")

        # 4. Repo Health Calculation: 100 minus 10 per vuln, floored at 10.
        health_score = max(10, 100 - (total_vulns * 10))

        return {
            "repo_health": health_score,
            "health_label": "Excellent Health" if health_score > 80 else "Needs Review",
            "security_issues": total_vulns,
            "performance_ratio": f"{int(avg_maintainability)}%",
            # sorted() makes the response deterministic: iterating a raw set of
            # strings varies between processes (hash randomization), so the old
            # list(set(...))[:10] could return different APIs on each restart.
            "exposed_apis": sorted(found_apis)[:10],
        }
    except Exception as e:
        logger.error("Dashboard stats failed: %s", e)
        raise HTTPException(status_code=500, detail="Failed to aggregate repository stats")
async def analyze_file(request: FileRequest):
    """Deep analysis of one file: layer classification, summary, and tags.

    Raises:
        HTTPException(500): on any classifier failure.
    """
    try:
        prediction = classifier.predict(request.fileName, request.content)
        file_summary = classifier.generate_file_summary(request.content, request.fileName)
        file_tags = classifier.extract_tags(request.content, request.fileName)

        # Remember the embedding for this repo so semantic search can use it.
        repo = request.repoName
        if repo:
            REPO_CACHE.setdefault(repo, {})[request.fileName] = prediction["embedding"]

        payload = {
            "fileName": request.fileName,
            "layer": prediction["label"],
            "summary": file_summary,
            "tags": file_tags,
            "embedding": prediction["embedding"],
        }
        return payload
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def semantic_search(request: SearchRequest):
    """Vector-similarity search over caller-supplied or cached embeddings.

    Raises:
        HTTPException(500): on any failure in the similarity search.
    """
    try:
        vectors = request.embeddings
        # Fall back to the server-side cache when the client sent none;
        # .get() yields None (falsy) for unknown repos, matching the old
        # explicit membership check.
        if not vectors and request.repoName:
            vectors = REPO_CACHE.get(request.repoName)
        if not vectors:
            # Nothing to search against — empty result set, not an error.
            return {"results": []}
        matches = classifier.semantic_search(request.query, vectors)
        return {"results": matches}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def chat(request: ChatRequest):
    """RAG-based chat using provided file context.

    Builds a single prompt from the caller-supplied code snippets plus the
    user question and delegates generation to the shared LLM engine.

    Raises:
        HTTPException(500): on any failure while building the prompt or
            generating the response.
    """
    start_time = time.time()
    logger.info(f"Received Chat Request for {request.repoName}")
    try:
        # Lazy import: only touch the (heavy) model loader when chatting.
        from app.core.model_loader import llm_engine

        # Concatenate every context file into one delimited text block.
        context_str = ""
        for item in request.context:
            context_str += f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
        has_context = len(request.context) > 0

        # Prompt body is flush-left inside the triple quote so no stray
        # indentation leaks into the model input.
        prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"
INSTRUCTIONS:
1. Use the provided CONTEXT to answer.
2. If context is missing, state: "I am using general knowledge as I don't have specific snippets for this."
3. Use markdown for code.
CONTEXT:
{context_str if has_context else "[(NO CODE SNIPPETS PROVIDED)]"}
USER QUESTION:
{request.query}
"""
        response = llm_engine.generate_text(prompt)
        logger.info(f"Chat response generated in {time.time() - start_time:.2f}s")
        return {"response": response}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def generate_guide(request: GuideRequest):
    """Produce markdown documentation for the given repository files.

    Raises:
        HTTPException(500): if the guide generator fails.
    """
    try:
        doc = guide_generator.generate_markdown(request.repoName, request.filePaths)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
    return {"markdown": doc}
| # 5. Application Entry Point | |
if __name__ == "__main__":
    # 7860 keeps HF Spaces happy; override via the PORT env var for local dev
    # (e.g. PORT=8000).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))