import os
import re
import logging
import traceback
import time
from typing import List, Optional, Dict
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# Load environment variables early; app.predictor may read them at import time
load_dotenv()

from app.predictor import classifier, guide_generator, reviewer
# Note: reviewer.service (used by the batch endpoints below) exposes the
# underlying AIReviewerService.

# 1. Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Initialize FastAPI
app = FastAPI(title="GitGud AI Service")
main = app  # Alias for compatibility

# Global embedding cache
# Structure: { "repo_name": { "file_path": [embedding_vector] } }
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
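

def cache_embedding(repo_name: Optional[str], file_name: str, embedding: List[float]) -> None:
    """Stores a file's embedding in REPO_CACHE, creating the repo bucket on first use.

    Consolidates the cache-write pattern shared by /classify and /analyze-file;
    a no-op when no repo_name is supplied.
    """
    if repo_name:
        REPO_CACHE.setdefault(repo_name, {})[file_name] = embedding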


# 3. Data Models
class FileRequest(BaseModel):
    fileName: str
    content: Optional[str] = None
    repoName: Optional[str] = None


class BatchReviewRequest(BaseModel):
    files: List[FileRequest]


class GuideRequest(BaseModel):
    repoName: str
    filePaths: List[str]


class SearchRequest(BaseModel):
    query: str
    embeddings: Optional[Dict[str, List[float]]] = None  # Path -> Embedding
    repoName: Optional[str] = None


class ChatRequest(BaseModel):
    query: str
    context: List[Dict[str, str]]  # List of { "fileName": str, "content": str }
    repoName: str
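
# Example ChatRequest payload (illustrative values only):
#   {"query": "Where is auth handled?",
#    "context": [{"fileName": "app/auth.py", "content": "<source>"}],
#    "repoName": "gitgud-ai"}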


# 4. Endpoints
@app.get("/")
def health_check():
    """Checks server status, GPU availability, and cached data."""
    return {
        "status": "online",
        "model": "microsoft/codebert-base",
        "device": getattr(classifier, "device", "cpu"),
        "cached_repos": list(REPO_CACHE.keys()),
    }
@app.get("/usage")
def get_usage():
"""Returns AI Service usage statistics."""
from app.core.model_loader import llm_engine
return llm_engine.get_usage_stats()
@app.post("/classify")
async def classify_file(request: FileRequest):
"""Classifies file into architectural layers and caches embeddings."""
try:
result = classifier.predict(request.fileName, request.content)
# Cache embedding if repoName is provided
if request.repoName:
if request.repoName not in REPO_CACHE:
REPO_CACHE[request.repoName] = {}
REPO_CACHE[request.repoName][request.fileName] = result["embedding"]
return {
"fileName": request.fileName,
"layer": result["label"],
"confidence": result["confidence"],
"embedding": result["embedding"]
}
except Exception as e:
logger.error(f"Classify failed: {e}")
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
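
# Example request (assuming the service is running locally on the default port 7860):
#   curl -X POST localhost:7860/classify -H "Content-Type: application/json" \
#        -d '{"fileName": "app/main.py", "content": "import os", "repoName": "gitgud-ai"}'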
@app.post("/review-batch-code")
async def review_batch_code(request: BatchReviewRequest):
"""Batch review with detailed metrics and suggestions."""
try:
reviews = reviewer.service.review_batch_code(request.files)
total_files = len(reviews)
total_vulns = sum(len(r.get("vulnerabilities", [])) for r in reviews)
# Calculate Average Maintainability
m_scores = [r.get("metrics", {}).get("maintainability", 8) for r in reviews]
avg_maint = sum(m_scores) / max(total_files, 1)
return {
"totalFiles": total_files,
"totalVulnerabilities": total_vulns,
"averageMaintainability": round(avg_maint, 1),
"results": reviews,
}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@app.post("/repo-dashboard-stats")
async def get_dashboard_stats(request: BatchReviewRequest):
"""Aggregated stats for frontend dashboards including health and API sniffing."""
try:
raw_reviews = reviewer.service.review_batch_code(request.files)
# 1. Security Count
total_vulns = sum(len(r.get("vulnerabilities", [])) for r in raw_reviews)
# 2. Performance/Maintainability Ratio
scores = [r.get("metrics", {}).get("maintainability", 8) for r in raw_reviews]
avg_maintainability = (sum(scores) / len(scores)) * 10 if scores else 0
# 3. API Sniffing (Regex)
found_apis = []
for f in request.files:
if f.content:
matches = re.findall(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', f.content.lower())
for match in matches:
found_apis.append(f"/{match}")
# 4. Repo Health Calculation
health_score = max(10, 100 - (total_vulns * 10))
return {
"repo_health": health_score,
"health_label": "Excellent Health" if health_score > 80 else "Needs Review",
"security_issues": total_vulns,
"performance_ratio": f"{int(avg_maintainability)}%",
"exposed_apis": list(set(found_apis))[:10]
}
except Exception as e:
logger.error(f"Dashboard stats failed: {e}")
raise HTTPException(status_code=500, detail="Failed to aggregate repository stats")
@app.post("/analyze-file")
async def analyze_file(request: FileRequest):
"""Deep analysis: Summary, Tags, and Layer Classification."""
try:
result = classifier.predict(request.fileName, request.content)
summary = classifier.generate_file_summary(request.content, request.fileName)
tags = classifier.extract_tags(request.content, request.fileName)
if request.repoName:
if request.repoName not in REPO_CACHE:
REPO_CACHE[request.repoName] = {}
REPO_CACHE[request.repoName][request.fileName] = result["embedding"]
return {
"fileName": request.fileName,
"layer": result["label"],
"summary": summary,
"tags": tags,
"embedding": result["embedding"],
}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@app.post("/semantic-search")
async def semantic_search(request: SearchRequest):
"""Search code using natural language and vector similarity."""
try:
embeddings = request.embeddings
if not embeddings and request.repoName and request.repoName in REPO_CACHE:
embeddings = REPO_CACHE[request.repoName]
if not embeddings:
return {"results": []}
results = classifier.semantic_search(request.query, embeddings)
return {"results": results}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
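
# Example: search using only the server-side cache (assumes the repo's files were
# previously sent to /classify or /analyze-file with repoName="gitgud-ai"):
#   curl -X POST localhost:7860/semantic-search -H "Content-Type: application/json" \
#        -d '{"query": "where are embeddings cached?", "repoName": "gitgud-ai"}'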
@app.post("/chat")
async def chat(request: ChatRequest):
"""RAG-based chat using provided file context."""
start_time = time.time()
logger.info(f"Received Chat Request for {request.repoName}")
try:
from app.core.model_loader import llm_engine
context_str = ""
for item in request.context:
context_str += f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
has_context = len(request.context) > 0
prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"
INSTRUCTIONS:
1. Use the provided CONTEXT to answer.
2. If context is missing, state: "I am using general knowledge as I don't have specific snippets for this."
3. Use markdown for code.
CONTEXT:
{context_str if has_context else "[(NO CODE SNIPPETS PROVIDED)]"}
USER QUESTION:
{request.query}
"""
response = llm_engine.generate_text(prompt)
logger.info(f"Chat response generated in {time.time() - start_time:.2f}s")
return {"response": response}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
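
# Example request (illustrative payload; with an empty context list the prompt
# falls back to general knowledge, per the instructions above):
#   curl -X POST localhost:7860/chat -H "Content-Type: application/json" \
#        -d '{"query": "What does main.py do?", "context": [], "repoName": "gitgud-ai"}'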
@app.post("/generate-guide")
async def generate_guide(request: GuideRequest):
"""Generates markdown documentation for the repo."""
try:
markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
return {"markdown": markdown}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))


# 5. Application Entry Point
if __name__ == "__main__":
    # Port 7860 matches the Hugging Face Spaces default; set PORT (e.g. 8000) for local dev
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
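# Local development alternative (assumes uvicorn is on PATH; --reload restarts on changes):
#   uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload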