# commitguard-env/scripts/evaluate.py
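"""Evaluate a commit-vulnerability classifier (a full model or a LoRA adapter on
top of a base model) against a JSONL test set. Reports binary accuracy, per-CWE
accuracy, precision, recall, F1, and the confusion matrix, and writes the full
per-sample predictions to a JSON results file."""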
import argparse
import json
import sys
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Add project root to path for imports
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from agent_prompt import SYSTEM_PROMPT
from commitguard_env.parse_action import parse_action
def format_eval_prompt(sample):
# Matches the format_prompt in train_grpo.py for consistency
return (
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
f"Analyze this commit and submit your verdict.\n\n"
f"Code diff:\n```diff\n{sample['diff']}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
def evaluate(model_path, test_file, is_lora=False, base_model=None, output_file="eval_results.json"):
"""
Run model on test samples, compute accuracy metrics.
"""
print(f"Loading model from {model_path}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model
if is_lora:
if not base_model:
raise ValueError("base_model is required if is_lora=True")
print(f"Loading LoRA adapter from {model_path} with base model {base_model}")
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = base_model,
max_seq_length = 2048,
load_in_4bit = True,
)
model = PeftModel.from_pretrained(model, model_path)
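        # for_inference() puts the unsloth model into inference mode, enabling
        # its faster generation path.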
FastLanguageModel.for_inference(model)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load test data
print(f"Loading test data from {test_file}...")
with open(test_file, "r", encoding="utf-8") as f:
samples = [json.loads(line) for line in f if line.strip()]
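    # Each JSONL record is expected to provide at least: sample_id, diff,
    # is_vulnerable, and (optionally) a cwe label.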
results = {
"summary": {
"total": len(samples),
"correct_binary": 0,
"correct_cwe": 0,
"false_positives": 0,
"false_negatives": 0,
"binary_accuracy": 0,
"cwe_accuracy": 0,
"false_positive_rate": 0,
"false_negative_rate": 0,
"cwe_breakdown": {},
},
"predictions": [],
}
print(f"Starting evaluation on {len(samples)} samples...")
for i, sample in enumerate(samples):
prompt = format_eval_prompt(sample)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
            # Greedy decoding for a deterministic evaluation; temperature is
            # ignored when do_sample=False, so it is not passed. pad_token_id is
            # set explicitly because Llama tokenizers define no pad token.
            output = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# Use the official parser from commitguard_env
prediction = parse_action(response)
gt_vulnerable = bool(sample["is_vulnerable"])
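        # parse_action may fail to extract a verdict; unparseable responses
        # default to a "not vulnerable" prediction.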
pred_vulnerable = bool(prediction.is_vulnerable) if prediction.is_vulnerable is not None else False
correct = pred_vulnerable == gt_vulnerable
if correct:
results["summary"]["correct_binary"] += 1
if gt_vulnerable and not pred_vulnerable:
results["summary"]["false_negatives"] += 1
elif not gt_vulnerable and pred_vulnerable:
results["summary"]["false_positives"] += 1
cwe = sample.get("cwe") or "CWE-OTHER"
if cwe not in results["summary"]["cwe_breakdown"]:
results["summary"]["cwe_breakdown"][cwe] = {"total": 0, "correct": 0, "accuracy": 0}
results["summary"]["cwe_breakdown"][cwe]["total"] += 1
if correct:
results["summary"]["cwe_breakdown"][cwe]["correct"] += 1
if gt_vulnerable and correct and prediction.vuln_type and prediction.vuln_type.strip().upper() == cwe.strip().upper():
results["summary"]["correct_cwe"] += 1
results["predictions"].append({
"sample_id": sample["sample_id"],
"ground_truth": gt_vulnerable,
"predicted": pred_vulnerable,
"predicted_cwe": prediction.vuln_type,
"actual_cwe": cwe,
"response": response,
})
if (i + 1) % 10 == 0:
print(f" Processed {i+1}/{len(samples)} samples...")
# Final summary stats
summary = results["summary"]
total = summary["total"]
vuln_count = sum(1 for s in samples if s["is_vulnerable"])
safe_count = total - vuln_count
    # Confusion matrix: true positives are the vulnerable samples that were not
    # missed, true negatives the safe samples that were not flagged.
    tp = vuln_count - summary["false_negatives"]
fp = summary["false_positives"]
fn = summary["false_negatives"]
tn = safe_count - fp
summary["tp"] = tp
summary["fp"] = fp
summary["tn"] = tn
summary["fn"] = fn
summary["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
summary["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
p, r = summary["precision"], summary["recall"]
summary["f1"] = 2 * p * r / (p + r) if (p + r) > 0 else 0
summary["binary_accuracy"] = summary["correct_binary"] / total if total > 0 else 0
summary["cwe_accuracy"] = summary["correct_cwe"] / vuln_count if vuln_count > 0 else 0
summary["false_positive_rate"] = summary["false_positives"] / safe_count if safe_count > 0 else 0
summary["false_negative_rate"] = summary["false_negatives"] / vuln_count if vuln_count > 0 else 0
for cwe in summary["cwe_breakdown"]:
stats = summary["cwe_breakdown"][cwe]
stats["accuracy"] = stats["correct"] / stats["total"] if stats["total"] > 0 else 0
print(f"\nEvaluation Complete:")
print(f" Binary Accuracy: {summary['binary_accuracy']:.2%}")
print(f" Precision: {summary['precision']:.2%}")
print(f" Recall: {summary['recall']:.2%}")
print(f" F1 Score: {summary['f1']:.2%}")
print(f" CWE Accuracy: {summary['cwe_accuracy']:.2%}")
print(f" Confusion: TP={summary['tp']} FP={summary['fp']} TN={summary['tn']} FN={summary['fn']}")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {output_file}")
return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a commit-vulnerability model on a JSONL test set.")
    parser.add_argument("--model-path", default="meta-llama/Llama-3.2-3B-Instruct",
                        help="Model ID or local path; the LoRA adapter path when --is-lora is set.")
    parser.add_argument("--test-file", default="data/devign_test.jsonl",
                        help="JSONL test set with sample_id/diff/is_vulnerable/cwe fields.")
    parser.add_argument("--is-lora", action="store_true",
                        help="Treat --model-path as a LoRA adapter loaded on top of --base-model.")
    parser.add_argument("--base-model", default="meta-llama/Llama-3.2-3B-Instruct",
                        help="Base model to attach the LoRA adapter to.")
    parser.add_argument("--output", default="eval_results.json",
                        help="Path for the JSON results file.")
args = parser.parse_args()
evaluate(args.model_path, args.test_file, args.is_lora, args.base_model, args.output)
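
# Example invocations (the LoRA adapter path below is illustrative, not a path
# shipped with this repo):
#   python scripts/evaluate.py --test-file data/devign_test.jsonl
#   python scripts/evaluate.py --model-path outputs/commitguard-lora --is-lora \
#       --base-model meta-llama/Llama-3.2-3B-Instruct --output eval_results.json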