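"""Evaluate a fine-tuned CommitGuard model against the environment server.

Each test sample is run as a short episode: the model generates an action,
the environment at --env_url scores it via /reset and /step, and the final
reward plus the binary correctness flag are aggregated and written to
eval_results_trained.json.
"""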
import json
import argparse
import os
import requests
import torch
import sys
from pathlib import Path
from unsloth import FastLanguageModel

# Add project root to sys.path
root_dir = Path(__file__).resolve().parent.parent
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from commitguard_env.grpo_prompt import get_agent_prompt, SYSTEM_PROMPT

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default="outputs/commitguard-llama-3b/final")
    parser.add_argument("--test_file", default="data/devign_test.jsonl")
    parser.add_argument("--env_url", default="http://localhost:8000")
    parser.add_argument("--max_samples", type=int, default=100)
    args = parser.parse_args()

    print(f"Loading merged model from {args.model_path}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_path,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=False,
    )
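    # Switch the model into inference mode (enables Unsloth's fast generation path).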
    FastLanguageModel.for_inference(model)

    with open(args.test_file, "r") as f:
        test_samples = [json.loads(line) for line in f]
    
    test_samples = test_samples[:args.max_samples]
    
    results = []
    correct_binary = 0   # episodes the environment judged correct
    total_reward = 0.0   # sum of final episode rewards

    print(f"Starting evaluation on {len(test_samples)} samples...")

    for i, sample in enumerate(test_samples):
        sample_id = sample["sample_id"]
        # Reset environment for specific sample
        r = requests.post(f"{args.env_url}/reset", json={"sample_id": sample_id})
        if r.status_code != 200:
            print(f"Failed to reset for {sample_id}")
            continue
            
        obs = r.json()["observation"]
        done = False
        step_idx = 0
        episode_reward = 0.0
        res = {}  # last /step response; stays empty if the first step request fails
        
        print(f"[{i+1}/{len(test_samples)}] Evaluating {sample_id}...", end="\r")
        
        # Roll out the episode, allowing at most 5 agent turns.
        while not done and step_idx < 5:
            prompt = get_agent_prompt(obs["diff"], obs["available_files"], obs["step_idx"])
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ]
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to("cuda")
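            # No sampling arguments are passed to generate(), so decoding
            # follows the checkpoint's generation_config defaults.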
            
            with torch.no_grad():
                outputs = model.generate(input_ids=inputs, max_new_tokens=512, use_cache=True)
            
            # Decode only the newly generated tokens, skipping the prompt prefix.
            generated_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
            
            # Step in environment
            r = requests.post(f"{args.env_url}/step", json={"action": generated_text})
            if r.status_code != 200:
                print(f"\nError at step {step_idx}: {r.text}")
                break
                
            res = r.json()
            obs = res["observation"]
            episode_reward = res["reward"]  # keep the latest reward; the terminal step carries the episode reward
            done = res["done"]
            step_idx += 1

        total_reward += episode_reward
        # is_correct info is passed in 'info' field by environment
        is_correct = res.get("info", {}).get("is_correct", False)
        if is_correct:
            correct_binary += 1
            
        results.append({
            "sample_id": sample_id,
            "reward": episode_reward,
            "is_correct": is_correct,
            "ground_truth": sample.get("is_vulnerable")
        })

    n_samples = len(test_samples)
    avg_reward = total_reward / n_samples if n_samples else 0.0
    accuracy = (correct_binary / n_samples) * 100 if n_samples else 0.0
    
    print(f"\nEvaluation Complete!")
    print(f"Average Reward: {avg_reward:.4f}")
    print(f"Accuracy: {accuracy:.2f}%")

    with open("eval_results_trained.json", "w") as f:
        json.dump({
            "average_reward": avg_reward,
            "accuracy": accuracy,
            "detailed_results": results
        }, f, indent=2)

if __name__ == "__main__":
    main()