# Author: Pranav Pc
# Final Deploy — commit 4b82ab5
from transformers import RobertaTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import json
from pathlib import Path
class VulnerabilityDataset(Dataset):
    """PyTorch dataset over a JSONL file of vulnerability-detection samples.

    Each line of the file is a JSON object with a ``func`` field (source code
    string) and a ``target`` field (0/1 label). Samples are tokenized lazily
    in ``__getitem__``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load all JSONL records from *data_path* into memory.

        Raises FileNotFoundError if the file does not exist.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {data_path}")
        # One JSON object per non-blank line.
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = [
                json.loads(stripped)
                for stripped in (line.strip() for line in f)
                if stripped
            ]
        print(f"{data_path.name}: {len(self.data)} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return model-ready tensors."""
        record = self.data[idx]
        encoded = self.tokenizer(
            record["func"],  # confirmed correct
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Tokenizer returns (1, max_length); squeeze to (max_length,) so the
        # DataLoader's default collate produces (batch, max_length).
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            # confirmed correct (0/1)
            "labels": torch.tensor(record["target"], dtype=torch.long),
        }
def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Download (or load from cache) the RoBERTa-style tokenizer for *model_name*."""
    print(f"Tokenizer: {model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    return tokenizer
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build train/valid/test DataLoaders over JSONL vulnerability datasets.

    Args:
        train_path / valid_path / test_path: paths to JSONL split files.
        tokenizer: tokenizer passed through to ``VulnerabilityDataset``.
        batch_size: samples per batch for all three loaders.
        max_length: tokenizer truncation/padding length.
        num_workers: DataLoader worker processes (0 = load in main process).

    Returns:
        (train_loader, valid_loader, test_loader) — train is shuffled,
        valid/test are not.

    Raises:
        FileNotFoundError: if any split file is missing (from the dataset).
        RuntimeError: if the training split contains no samples.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)
    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # BUGFIX: persistent_workers was hard-coded True, which makes DataLoader
    # raise ValueError when num_workers=0 (a valid value per this signature,
    # and the usual debug setting). It is only meaningful with live workers.
    keep_workers = num_workers > 0

    def _make_loader(dataset, shuffle):
        # Shared loader settings; only the dataset and shuffle flag vary.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=keep_workers,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    valid_loader = _make_loader(valid_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)
    return train_loader, valid_loader, test_loader