# gitgud-ai / app/predictor.py
# Page residue from the repository view: "Create app/predictor.py" (commit da07baf, CodeCommunity).
# Imports needed to load and run the CodeBERT model.
import logging
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Configure logging at INFO level so the progress messages emitted below are
# shown, and create the module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CodeClassifier:
    """Assign repository files to an architectural layer.

    Classification is two-stage:

    1. Ordered substring rules on the lower-cased path (fast, no model call).
    2. If no rule matches, the cosine similarity between a CodeBERT [CLS]
       embedding of the file (content if given, otherwise the path) and the
       pre-computed embeddings of one keyword "anchor" description per layer.
    """

    # Hugging Face checkpoint loaded when no other name is passed to __init__.
    DEFAULT_MODEL_NAME = "microsoft/codebert-base"
    # A model-chosen label is only accepted above this cosine similarity.
    MIN_CONFIDENCE = 0.25
    # Only this many leading characters of file content are embedded
    # (enough to capture imports / class definitions).
    MAX_CONTENT_CHARS = 1000
    # Tokenizer truncation length (CodeBERT max is 512 tokens).
    MAX_TOKENS = 512

    # Ordered (label, substrings) rules; the first rule with any substring
    # present in the lower-cased path wins, so order is significant.
    PATH_RULES = (
        (
            "Frontend",
            ("/components/", "/pages/", "/views/", ".jsx", ".tsx", ".css", "tailwind"),
        ),
        (
            "Backend",
            (
                "/controllers/",
                "/modules/",
                "/services/",
                ".controller.ts",
                ".service.ts",
                "dto",
            ),
        ),
        (
            "Security",
            ("auth", "guard", "strategy", "jwt", "passport", "middleware"),
        ),
        (
            "DevOps",
            ("docker", "k8s", "github/workflows", "tsconfig", "package.json"),
        ),
        ("Testing", ("test", "spec", "e2e", "jest")),
    )

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the tokenizer/model and pre-compute the label anchors.

        The first run downloads the checkpoint from Hugging Face (roughly
        500 MB for the default model); later runs use the local cache.

        Args:
            model_name: Checkpoint passed to ``from_pretrained``.

        Raises:
            Exception: Whatever ``from_pretrained`` raises is logged and
                re-raised unchanged.
        """
        logger.info("⏳ Initializing AI Service...")

        # Detect hardware: CUDA if available, otherwise CPU; MPS takes over
        # when present. Older PyTorch builds have no ``torch.backends.mps``
        # attribute, so look it up defensively instead of crashing.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            self.device = "mps"
        logger.info(f"🚀 Running on device: {self.device}")

        try:
            logger.info(
                "📥 Loading CodeBERT Model (this may take a minute first time)..."
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name).to(self.device)
            logger.info("✅ CodeBERT Loaded Successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            # Bare raise keeps the original traceback intact.
            raise

        # Semantic anchors: one keyword description per layer.
        self.labels = {
            "Frontend": "import react component from view styles css html dom window document state props effect",
            "Backend": "import express nest controller service entity repository database sql mongoose route api async await req res dto",
            "Security": "import auth passport jwt strategy bcrypt verify token secret guard password user login session middleware",
            "DevOps": "docker build image container kubernetes yaml env port host volume deploy pipeline stage steps runs-on",
            "Testing": "describe it expect test mock spy jest beforeall aftereach suite spec assert",
        }
        self.label_embeddings = self._precompute_label_embeddings()

    def _get_embedding(self, text: str):
        """Return the [CLS] embedding of *text* as a ``(1, hidden)`` tensor.

        The input is truncated to ``MAX_TOKENS`` tokens and the model runs
        without gradient tracking. The hidden size is 768 for the default
        CodeBERT checkpoint.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_TOKENS,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**encoded)
        # Index 0 along the sequence axis is the [CLS] token, used here as the
        # representation of the whole input.
        return outputs.last_hidden_state[:, 0, :]

    def _precompute_label_embeddings(self):
        """Embed every anchor description in ``self.labels`` once at startup."""
        logger.info("🧠 Pre-computing semantic anchors for classification...")
        return {
            label: self._get_embedding(description)
            for label, description in self.labels.items()
        }

    @staticmethod
    def _result(label, conf=1.0, emb=None):
        """Build the standard prediction dict returned by :meth:`predict`."""
        return {
            "label": label,
            "confidence": conf,
            "embedding": emb if emb is not None else [],
        }

    def predict(self, file_path: str, content: "str | None" = None) -> dict:
        """Determine the layer of a file (Frontend, Backend, ...).

        Args:
            file_path: Path of the file; matched case-insensitively against
                ``PATH_RULES``.
            content: Optional file content. When non-empty, its first
                ``MAX_CONTENT_CHARS`` characters are embedded instead of the
                path.

        Returns:
            ``{"label": str, "confidence": float, "embedding": list}``.
            Rule matches return confidence ``1.0`` and an empty embedding.
            Model results carry the similarity score and the embedding as
            produced by ``Tensor.tolist()`` on the ``(1, hidden)`` tensor,
            i.e. a nested one-row list. If the model path raises, the error
            is logged and ``("Generic", 0.0, [])`` is returned.
        """
        path = file_path.lower()

        # 1. Fast path: rule-based checks (instant, high precision).
        for label, needles in self.PATH_RULES:
            if any(needle in path for needle in needles):
                return self._result(label)

        # 2. Slow path: semantic classification for paths no rule recognised.
        try:
            # Content is the better signal; fall back to the path itself.
            text_to_analyze = (
                content[: self.MAX_CONTENT_CHARS] if content else file_path
            )
            target_embedding = self._get_embedding(text_to_analyze)
            # Plain lists so the result is JSON-serialisable.
            embedding_list = target_embedding.tolist()

            best_label = "Generic"
            highest_score = -1.0
            for label, anchor_embedding in self.label_embeddings.items():
                # Cosine similarity lies in [-1, 1].
                score = F.cosine_similarity(target_embedding, anchor_embedding).item()
                if score > highest_score:
                    highest_score = score
                    best_label = label

            # Only accept the model's opinion if it is somewhat confident.
            if highest_score > self.MIN_CONFIDENCE:
                return self._result(best_label, highest_score, embedding_list)
            return self._result("Generic", highest_score, embedding_list)
        except Exception as e:
            # Deliberate best-effort: a model failure must not abort the scan.
            logger.error(f"AI Classification failed for {file_path}: {e}")
            return self._result("Generic", 0.0)
class GuideGenerator:
    """Build a Markdown developer guide from a repository's file list.

    Each path is classified into a layer, the tech stack / features / dev
    tools are inferred from path substrings, and the findings are assembled
    into one Markdown document by :meth:`generate_markdown`.
    """

    # (suffixes, language) pairs checked in order against the lower-cased path.
    _LANGUAGE_SUFFIXES = (
        ((".ts", ".tsx"), "TypeScript"),
        ((".js", ".jsx"), "JavaScript"),
        ((".py",), "Python"),
        ((".go",), "Go"),
        ((".rs",), "Rust"),
        ((".java",), "Java"),
    )

    # (substring, tool name) pairs looked up in the joined, lower-cased paths.
    _DEV_TOOL_MARKERS = (
        ("eslint", "ESLint"),
        ("prettier", "Prettier"),
        ("jest", "Jest"),
        ("cypress", "Cypress"),
        ("github/workflows", "GitHub Actions"),
        ("husky", "Husky"),
        ("tailwind", "Tailwind CSS"),
        ("vite", "Vite"),
        ("webpack", "Webpack"),
    )

    def __init__(self, classifier=None):
        """Set up the indicator tables.

        Args:
            classifier: Optional object exposing ``predict(path) -> dict``.
                When omitted, the module-level ``classifier`` is looked up
                lazily at call time (it is defined after this class).
        """
        self._classifier = classifier
        self.tech_stacks = {
            "React": ["react", "jsx", "tsx", "next.config.js"],
            "Vue": ["vue", "nuxt.config.js"],
            "Angular": ["angular.json"],
            "Svelte": ["svelte.config.js"],
            "NestJS": ["nest-cli.json", ".module.ts"],
            "Express": ["express", "server.js", "app.js"],
            "FastAPI": ["fastapi", "main.py"],
            "Django": ["django", "manage.py"],
            "Flask": ["flask", "app.py"],
            "Spring Boot": ["pom.xml", "build.gradle", "src/main/java"],
            "Go": ["go.mod", "main.go"],
            "Rust": ["Cargo.toml", "src/main.rs"],
        }
        self.tools = {
            "Docker": ["Dockerfile", "docker-compose.yml"],
            "Kubernetes": ["k8s", "helm", "charts/"],
            "TypeScript": ["tsconfig.json", ".ts"],
            "Tailwind CSS": ["tailwind.config.js"],
            "Prisma": ["schema.prisma"],
            "GraphQL": [".graphql", "schema.gql"],
            "PostgreSQL": ["postgresql", "pg"],
            "MongoDB": ["mongoose", "mongodb"],
            "Redis": ["redis"],
        }

    def _classify(self, path: str) -> dict:
        """Classify *path* with the injected classifier, else the module-level one."""
        active = self._classifier if self._classifier is not None else classifier
        return active.predict(path)

    def detect_stack(self, files: list[str]) -> dict:
        """Infer languages, frameworks and tools from file paths.

        Matching is case-insensitive: both the path and every indicator are
        lower-cased, so mixed-case indicators such as ``Dockerfile`` or
        ``Cargo.toml`` actually match.

        Returns:
            ``{"languages": set, "frameworks": set, "tools": set}``.
        """
        detected = {"languages": set(), "frameworks": set(), "tools": set()}
        for file in files:
            path = file.lower()

            # Languages: at most one per file, decided by extension.
            for suffixes, language in self._LANGUAGE_SUFFIXES:
                if path.endswith(suffixes):
                    detected["languages"].add(language)
                    break

            # Frameworks
            for framework, indicators in self.tech_stacks.items():
                if any(ind.lower() in path for ind in indicators):
                    detected["frameworks"].add(framework)

            # Tools
            # NOTE(review): plain substring matching is loose (e.g. "pg" also
            # occurs inside ".jpg"); left as-is to keep the heuristic stable.
            for tool, indicators in self.tools.items():
                if any(ind.lower() in path for ind in indicators):
                    detected["tools"].add(tool)
        return detected

    def _generate_tree(self, files: list[str]) -> str:
        """Render an ASCII tree (depth <= 3) with folder annotations.

        Noise directories and ``.DS_Store`` files are skipped, each directory
        lists at most 12 entries followed by ``...``, and the output is capped
        at 50 lines.
        """
        noise = {
            "node_modules",
            ".git",
            "__pycache__",
            "dist",
            "build",
            ".idea",
            ".vscode",
        }

        # 1-2. Filter paths and build the nested dictionary, three levels deep.
        tree = {}
        for f in files:
            parts = f.split("/")
            if any(p in noise for p in parts) or f.endswith(".DS_Store"):
                continue
            node = tree
            for part in parts[:3]:
                node = node.setdefault(part, {})

        # 3. Annotations shown next to well-known, non-empty folders.
        descriptions = {
            "src": "Core application source code",
            "app": "Main application logic",
            "components": "Reusable UI components",
            "pages": "Route/Page definitions",
            "api": "API endpoints and services",
            "utils": "Utility functions and helpers",
            "lib": "External libraries and configurations",
            "test": "Unit and integration tests",
            "tests": "Test suites",
            "docs": "Project documentation",
            "public": "Static assets (images, fonts)",
            "assets": "Static media files",
            "server": "Backend server code",
            "client": "Frontend client application",
            "config": "Configuration files",
            "scripts": "Build and maintenance scripts",
            "prisma": "Database schema and migrations",
            "graphql": "GraphQL definitions",
        }
        priority = ("src", "app", "client", "server", "public")

        # 4. Render tree.
        lines = []

        def render(node, prefix=""):
            # Priority folders first, alphabetical within each group.
            keys = sorted(node, key=lambda k: (k not in priority, k))
            if len(keys) > 12:
                keys = keys[:12] + ["..."]
            for i, key in enumerate(keys):
                is_last = i == len(keys) - 1
                connector = "└── " if is_last else "├── "
                # The "..." placeholder is not a real key, hence .get().
                children = node.get(key)
                comment = ""
                if key in descriptions and children:
                    comment = f" # {descriptions[key]}"
                lines.append(f"{prefix}{connector}{key}{comment}")
                if children:
                    # Child prefixes are as wide as the 4-character connector
                    # so nested levels line up.
                    render(children, prefix + ("    " if is_last else "│   "))

        render(tree)
        return "\n".join(lines[:50])

    def _analyze_files(self, files: list[str]):
        """Classify every path and collect per-layer statistics.

        Returns:
            ``(stats, layer_map, low_confidence_files, file_embeddings)``
            where ``file_embeddings`` maps path -> flat 1-D float tensor.
        """
        stats = {
            "Frontend": 0,
            "Backend": 0,
            "Security": 0,
            "DevOps": 0,
            "Testing": 0,
            "Generic": 0,
        }
        layer_map = {}
        low_confidence_files = []
        file_embeddings = {}
        for f in files:
            # Only the path is classified; file content is not available here.
            prediction = self._classify(f)
            layer = prediction["label"]
            confidence = prediction["confidence"]
            # .get() tolerates a label outside the predefined set.
            stats[layer] = stats.get(layer, 0) + 1
            layer_map[f] = layer
            if confidence < 0.4 and layer != "Generic":
                low_confidence_files.append((f, confidence))
            embedding = prediction["embedding"]
            if embedding:
                # The classifier may hand back a nested one-row list;
                # flattening gives one vector per file whichever shape arrives.
                file_embeddings[f] = torch.tensor(
                    embedding, dtype=torch.float32
                ).flatten()
        return stats, layer_map, low_confidence_files, file_embeddings

    def _find_couplings(self, file_embeddings: dict) -> list:
        """Return the top-5 cross-folder file pairs with similarity > 0.88."""
        couplings = []
        try:
            # Limit to the first 50 files for safety and performance.
            sample_paths = list(file_embeddings)[:50]
            for i, p1 in enumerate(sample_paths):
                for p2 in sample_paths[i + 1 :]:
                    # Skip same folder; rpartition yields "" for root-level
                    # files so they all count as sharing the root folder.
                    if p1.rpartition("/")[0] == p2.rpartition("/")[0]:
                        continue
                    t1 = file_embeddings[p1].cpu()
                    t2 = file_embeddings[p2].cpu()
                    score = F.cosine_similarity(
                        t1.unsqueeze(0), t2.unsqueeze(0)
                    ).item()
                    if score > 0.88:
                        couplings.append((p1, p2, score))
        except Exception as e:
            # Best-effort section: the guide is still produced without it.
            logger.error(f"Failed to calculate couplings: {e}")
        couplings.sort(key=lambda x: x[2], reverse=True)
        return couplings[:5]

    @staticmethod
    def _pick_commands(stack: dict):
        """Choose ``(install_cmd, run_cmd, test_cmd)`` for the detected stack."""
        install_cmd = "npm install"
        run_cmd = "npm run dev"
        test_cmd = "npm test"
        if "Python" in stack["languages"]:
            install_cmd = "pip install -r requirements.txt"
            run_cmd = "python main.py"
            test_cmd = "pytest"
        elif "Go" in stack["languages"]:
            install_cmd = "go mod download"
            run_cmd = "go run main.go"
            test_cmd = "go test ./..."
        elif "Rust" in stack["languages"]:
            install_cmd = "cargo build"
            run_cmd = "cargo run"
            test_cmd = "cargo test"
        if "Django" in stack["frameworks"]:
            run_cmd = "python manage.py runserver"
            test_cmd = "python manage.py test"
        elif "Spring Boot" in stack["frameworks"]:
            install_cmd = "./mvnw install"
            run_cmd = "./mvnw spring-boot:run"
            # Without this the guide would tell Java users to run `npm test`.
            test_cmd = "./mvnw test"
        if "Docker" in stack["tools"]:
            run_cmd += "\n# Or using Docker\ndocker-compose up --build"
        return install_cmd, run_cmd, test_cmd

    def generate_markdown(self, repo_name: str, files: list[str]) -> str:
        """Generate the full developer guide for *repo_name*.

        Args:
            repo_name: Repository name used in the title and clone command.
            files: Repository-relative file paths using ``/`` separators.

        Returns:
            The guide as a single Markdown string.
        """
        # 1. Layer analysis
        stats, layer_map, low_confidence_files, file_embeddings = (
            self._analyze_files(files)
        )
        # Avoid dividing by zero for an empty repository.
        total_files = len(files) if files else 1
        primary_layer = max(stats, key=stats.get)
        top_couplings = self._find_couplings(file_embeddings)
        # Lowest-confidence files first.
        top_refactors = sorted(low_confidence_files, key=lambda x: x[1])[:5]

        # 2. Stack & feature detection
        stack = self.detect_stack(files)
        features = self._detect_features(files, stats)
        dev_tools = self._detect_dev_tools(files)
        install_cmd, run_cmd, test_cmd = self._pick_commands(stack)

        # 3. Assemble guide (pieces are collected and joined once at the end)
        out = []
        add = out.append

        add(f"# {repo_name} Developer Guide\n\n")
        add("## AI Codebase Insights\n")
        add("Analysis powered by **CodeBERT** semantic vectors.\n\n")
        add(f"**Project DNA:** {self._get_project_dna(stats, total_files)}\n\n")
        add(f"**Quality Check:** {self._get_testing_status(stats, total_files)}\n\n")

        if top_refactors:
            add("### Code Health & Complexity\n")
            add("The AI flagged the following files as **Non-Standard** or **Complex** (Low Confidence).\n")
            add("These are good candidates for refactoring or documentation reviews:\n")
            for f, score in top_refactors:
                add(f"- `{f}` (Confidence: {int(score * 100)}%)\n")
            add("\n")

        if top_couplings:
            add("### Logical Couplings\n")
            add("The AI detected strong semantic connections between these file pairs (they share logic but not folders):\n")
            for p1, p2, score in top_couplings:
                add(f"- `{p1}` <--> `{p2}` ({int(score * 100)}% match)\n")
            add("\n")

        add("### Layer Composition\n")
        add("| Layer | Composition | Status |\n")
        add("| :--- | :--- | :--- |\n")
        for layer, count in stats.items():
            if count > 0:
                percentage = (count / total_files) * 100
                status = "Primary" if layer == primary_layer else "Detected"
                add(f"| {layer} | {percentage:.1f}% | {status} |\n")
        add("\n")

        add("## Key Features\n")
        if features:
            add("The following capabilities were inferred from the codebase structure:\n\n")
            for feature, description in features.items():
                add(f"- **{feature}**: {description}\n")
        else:
            add("No specific high-level features (Auth, Database, etc.) were explicitly detected from the file structure.\n")
        add("\n")

        add("## Architecture & Technologies\n")
        add("The project utilizes the following core technologies:\n\n")
        if stack["languages"]:
            add("**Languages**: " + ", ".join(sorted(stack["languages"])) + "\n")
        if stack["frameworks"]:
            add("**Frameworks**: " + ", ".join(sorted(stack["frameworks"])) + "\n")
        if stack["tools"]:
            add("**Infrastructure**: " + ", ".join(sorted(stack["tools"])) + "\n")
        if dev_tools:
            add("**Development Tools**: " + ", ".join(sorted(dev_tools)) + "\n")
        add("\n")

        add("## Getting Started\n\n")

        # Config section
        if any(f.endswith(".env") or f.endswith(".env.example") for f in files):
            add("### Configuration\n")
            add("This project relies on environment variables. \n")
            add("1. Find the `.env.example` file in the root directory.\n")
            add("2. Copy it to create a new `.env` file.\n")
            add(" `cp .env.example .env`\n")
            add("3. Fill in the required values (Database URL, API Keys, etc.).\n\n")

        add("### Prerequisites\n")
        add("Ensure you have the following installed:\n")
        add("- Git\n")
        if stack["languages"] & {"TypeScript", "JavaScript"}:
            add("- Node.js (LTS)\n")
        if "Python" in stack["languages"]:
            add("- Python 3.8+\n")
        if "Docker" in stack["tools"]:
            add("- Docker Desktop\n")

        add("\n### Installation\n")
        add("1. Clone and enter the repository:\n")
        add(" ```bash\n")
        add(f" git clone https://github.com/OWNER/{repo_name}.git\n")
        add(f" cd {repo_name}\n")
        add(" ```\n\n")
        add("2. Install project dependencies:\n")
        add(" ```bash\n")
        add(f" {install_cmd}\n")
        add(" ```\n\n")

        add("### Execution\n")
        add("Start the development environment:\n")
        add("```bash\n")
        add(f"{run_cmd}\n")
        add("```\n\n")

        if stats["Testing"] > 0:
            add("## Testing\n")
            add("Automated tests were detected. Run them using:\n")
            add(f"```bash\n{test_cmd}\n```\n\n")

        add("## Project Structure\n")
        add("Detailed tree view with AI-predicted layer labels:\n\n")
        add("```text\n")
        add(self._generate_tree_with_ai(files, layer_map))
        add("\n```\n\n")

        add("## Contribution Workflow\n\n")
        add("We welcome contributions! Please follow this detailed workflow to ensure smooth collaboration:\n\n")
        add("### 1. Find an Issue\n")
        add("- Browse the **Issues** tab for tasks.\n")
        add("- Look for labels like `good first issue` or `help wanted` if you are new.\n")
        add("- Comment on the issue to assign it to yourself before starting work.\n\n")
        add("### 2. Branching Strategy\n")
        add("Create a new branch from `main` using one of the following prefixes:\n")
        add("- `feat/`: For new features (e.g., `feat/user-auth`)\n")
        add("- `fix/`: For bug fixes (e.g., `fix/login-error`)\n")
        add("- `docs/`: For documentation changes (e.g., `docs/update-readme`)\n")
        add("- `refactor/`: For code improvements without logic changes\n")
        add("```bash\n")
        add("git checkout -b feat/your-feature-name\n")
        add("```\n\n")
        add("### 3. Development Standards\n")
        if dev_tools:
            add("Before committing, ensure your code meets the project standards:\n")
            if "ESLint" in dev_tools or "Prettier" in dev_tools:
                add("- **Linting**: Run `npm run lint` to fix style issues.\n")
            if "Jest" in dev_tools or "pytest" in dev_tools:
                add("- **Testing**: Run tests to ensure no regressions.\n")
        add("- Keep pull requests small and focused on a single task.\n\n")
        add("### 4. Commit Messages\n")
        add("We follow the **Conventional Commits** specification:\n")
        add("- `feat: add new search component`\n")
        add("- `fix: handle null pointer exception`\n")
        add("- `chore: update dependencies`\n")
        add("```bash\n")
        add("git commit -m 'feat: implement amazing feature'\n")
        add("```\n\n")
        add("### 5. Pull Request Process\n")
        add("1. Push your branch: `git push origin feat/your-feature-name`\n")
        add("2. Open a Pull Request against the `main` branch.\n")
        add("3. Fill out the PR template with details about your changes.\n")
        add("4. Wait for code review and address any feedback.\n")
        add("5. Once approved, your changes will be merged!\n\n")

        add("## About this Guide\n")
        add("This documentation was automatically generated by **GitGud AI**.\n")
        add("- **Engine:** CodeBERT (Transformer-based Language Model)\n")
        add("- **Analysis:** Zero-shot semantic classification of file paths and content.\n")
        add("- **Accuracy:** The 'Project DNA' and 'Layer Composition' metrics are derived from vector embeddings of your codebase, providing a mathematical approximation of your architecture.\n")
        return "".join(out)

    def _get_project_dna(self, stats: dict, total: int) -> str:
        """Turn the layer statistics into a one-paragraph project description."""
        # Guard against an empty repository (total == 0).
        total = total or 1
        backend_pct = (stats["Backend"] / total) * 100
        frontend_pct = (stats["Frontend"] / total) * 100
        ops_pct = (stats["DevOps"] / total) * 100
        if backend_pct > 50:
            return "This project is a **Backend-focused service**, likely an API or microservice. The majority of the codebase is dedicated to business logic and data handling."
        if frontend_pct > 50:
            return "This project is a **Frontend-heavy application**, focusing on UI/UX and client-side logic."
        if backend_pct > 30 and frontend_pct > 30:
            return "This is a balanced **Full-Stack Application**, containing significant logic for both client and server components."
        if ops_pct > 40:
            return "This repository appears to be an **Infrastructure or Configuration** project (IaC), heavily focused on deployment and orchestration."
        return "This is a **General purpose repository**, possibly a library or a mix of various utilities."

    def _get_testing_status(self, stats: dict, total: int) -> str:
        """Grade the share of test files: >20% excellent, >5% moderate, else low."""
        # Guard against an empty repository (total == 0).
        test_pct = (stats["Testing"] / (total or 1)) * 100
        if test_pct > 20:
            return "[Excellent] **Excellent Test Coverage**. A significant portion of the codebase is dedicated to testing."
        if test_pct > 5:
            return "[Moderate] **Moderate Testing**. Tests are present but may not cover all modules."
        return "[Low] **Low Test Coverage**. Very few test files were detected. Recommended to add more unit or integration tests."

    def _detect_features(self, files: list[str], stats: dict) -> dict:
        """Infer high-level features from path substrings.

        Authentication, Database and UI Architecture are only reported when
        the matching layer count in *stats* is non-zero.

        Returns:
            Mapping of feature name -> one-line description.
        """
        features = {}
        files_str = " ".join(files).lower()

        def mentions(*indicators):
            return any(x in files_str for x in indicators)

        # Authentication - only if Security files were detected
        if stats["Security"] > 0 and mentions(
            "/auth/",
            "/login",
            "/register",
            "passport",
            "jwt",
            "session",
            "bcrypt",
            "strategy",
        ):
            features["Authentication"] = (
                "Implements user authentication (Login/Signup/Session management)."
            )

        # Database - only if Backend files were detected
        if stats["Backend"] > 0 and mentions(
            "schema.prisma",
            "models.py",
            "migration",
            "sequelize",
            "typeorm",
            "mongoose",
            "/db/",
            "/database/",
        ):
            features["Database"] = (
                "Includes database schema definitions or ORM models."
            )

        # API
        if mentions("/api/", "controllers", "resolvers", "routes", "router", "endpoint"):
            features["API"] = "Exposes RESTful endpoints or GraphQL resolvers."

        # Realtime
        if mentions("socket", "websocket", "io.", "channel"):
            features["Real-time"] = "Uses WebSockets or real-time event channels."

        # UI Architecture - only if Frontend files were detected
        if stats["Frontend"] > 0 and mentions(
            "components/", "views/", ".tsx", ".jsx", ".vue", "/pages/"
        ):
            features["UI Architecture"] = "Modular component-based user interface."
        return features

    def _detect_dev_tools(self, files: list[str]) -> set:
        """Return the names of dev tools whose marker occurs in any path."""
        files_str = " ".join(files).lower()
        return {
            name for marker, name in self._DEV_TOOL_MARKERS if marker in files_str
        }

    def _generate_tree_with_ai(self, files: list[str], layer_map: dict) -> str:
        """Render an ASCII tree (depth <= 3, max 60 lines) with layer labels.

        Labels come from *layer_map* when the node is a classified file; only
        nodes missing from the map (typically folders) are classified here.
        ``Generic`` nodes are shown without a label.
        """
        max_lines = 60
        tree = {}
        for f in files:
            parts = f.split("/")
            if any(p in ["node_modules", ".git", "__pycache__", "dist"] for p in parts):
                continue
            curr = tree
            for part in parts[:3]:  # Depth 3
                curr = curr.setdefault(part, {})

        lines = []

        def render(node, path_prefix="", tree_prefix=""):
            keys = sorted(node)
            for i, key in enumerate(keys):
                # Lines past the cap are discarded anyway, so stop before
                # spending classifier calls on them.
                if len(lines) >= max_lines:
                    return
                is_last = i == len(keys) - 1
                full_path = f"{path_prefix}/{key}".strip("/")
                # Reuse the known layer; classify only unseen nodes.
                layer = layer_map.get(full_path)
                if layer is None:
                    layer = self._classify(full_path)["label"]
                label = f" [{layer}]" if layer != "Generic" else ""
                connector = "└── " if is_last else "├── "
                lines.append(f"{tree_prefix}{connector}{key}{label}")
                if node[key]:
                    render(
                        node[key],
                        full_path,
                        # Same width as the connector so levels line up.
                        tree_prefix + ("    " if is_last else "│   "),
                    )

        render(tree)
        return "\n".join(lines[:max_lines])
# Create the module-level singletons. These run at import time:
# CodeClassifier() loads the CodeBERT tokenizer and model in its constructor,
# so importing this module triggers that load (and the download on first run).
# GuideGenerator refers to `classifier` by this global name.
classifier = CodeClassifier()
guide_generator = GuideGenerator()