"""Data-loading utilities for vulnerability detection fine-tuning.

Provides a JSONL-backed PyTorch ``Dataset`` plus helpers to build a
tokenizer and the train/valid/test ``DataLoader`` triple.
"""

import json
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


class VulnerabilityDataset(Dataset):
    """PyTorch dataset for vulnerability detection.

    Reads a JSONL file where each non-blank line is an object with a
    ``"func"`` key (source code string) and a ``"target"`` key (0/1 label).
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load every non-blank JSONL line from *data_path* into memory.

        Args:
            data_path: Path (str or ``Path``) to a JSONL dataset file.
            tokenizer: Callable with the HuggingFace tokenizer interface,
                i.e. ``tokenizer(text, truncation=..., padding=...,
                max_length=..., return_tensors="pt")``.
            max_length: Token sequence length to pad/truncate to.

        Raises:
            FileNotFoundError: If *data_path* does not exist.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {data_path}")

        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines (e.g. trailing newline)
                    self.data.append(json.loads(line))

        print(f"{data_path.name}: {len(self.data)} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return model-ready tensors."""
        sample = self.data[idx]
        code = sample["func"]  # confirmed correct
        label = sample["target"]  # confirmed correct (0/1)

        encoding = self.tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0): the tokenizer returns tensors with a leading batch
        # dim of 1; DataLoader collation re-adds the batch dimension.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Load the RoBERTa-style tokenizer used by CodeT5 checkpoints."""
    # Imported lazily so the dataset/dataloader utilities in this module
    # remain usable without `transformers` installed.
    from transformers import RobertaTokenizer

    print(f"Tokenizer: {model_name}")
    return RobertaTokenizer.from_pretrained(model_name)


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader with the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # persistent_workers=True is rejected by DataLoader (ValueError)
        # when num_workers == 0, so gate it on the worker count.
        persistent_workers=num_workers > 0,
    )


def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Create train/valid/test DataLoaders from three JSONL files.

    Args:
        train_path: JSONL file with the training split.
        valid_path: JSONL file with the validation split.
        test_path: JSONL file with the test split.
        tokenizer: Tokenizer passed through to ``VulnerabilityDataset``.
        batch_size: Batch size for all three loaders.
        max_length: Token sequence length to pad/truncate to.
        num_workers: DataLoader worker processes (0 = load in-process).

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``; only the
        training loader shuffles.

    Raises:
        FileNotFoundError: If any of the three files is missing.
        RuntimeError: If the training file contains no samples.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    train_loader = _make_loader(train_dataset, batch_size, True, num_workers)
    valid_loader = _make_loader(valid_dataset, batch_size, False, num_workers)
    test_loader = _make_loader(test_dataset, batch_size, False, num_workers)

    return train_loader, valid_loader, test_loader