|
|
from transformers import RobertaTokenizer |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import torch |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class VulnerabilityDataset(Dataset):
    """Map-style dataset over a JSONL file of labeled source functions.

    Each line of the file is a JSON object with a ``"func"`` key (source
    code string) and a ``"target"`` key (integer class label). Samples are
    tokenized lazily in ``__getitem__``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load every non-blank JSONL line from *data_path* into memory.

        Raises:
            FileNotFoundError: if *data_path* does not exist.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Dataset file not found: {path}")

        # Keep only non-empty lines; each one is a standalone JSON record.
        with open(path, "r", encoding="utf-8") as handle:
            stripped = (raw.strip() for raw in handle)
            self.data = [json.loads(text) for text in stripped if text]

        print(f"{path.name}: {len(self.data)} samples")

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return model-ready tensors.

        Returns a dict with ``input_ids`` and ``attention_mask`` of shape
        ``(max_length,)`` and a scalar long tensor ``labels``.
        """
        record = self.data[idx]

        encoded = self.tokenizer(
            record["func"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer returns a batch of one; drop the leading batch dim.
        item = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(record["target"], dtype=torch.long),
        }
        return item
|
|
|
|
|
|
|
|
def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Fetch the pretrained RoBERTa-style tokenizer for *model_name*.

    Defaults to the CodeT5-base checkpoint, which uses a RoBERTa tokenizer.
    """
    print(f"Tokenizer: {model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    return tokenizer
|
|
|
|
|
|
|
|
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build train/valid/test DataLoaders over JSONL vulnerability datasets.

    Args:
        train_path, valid_path, test_path: paths to JSONL files consumed by
            ``VulnerabilityDataset``.
        tokenizer: tokenizer passed through to each dataset.
        batch_size: batch size shared by all three loaders.
        max_length: tokenizer truncation/padding length.
        num_workers: worker processes per loader; 0 means load in-process.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``; only the train
        loader shuffles.

    Raises:
        RuntimeError: if the training file contains no samples.
        FileNotFoundError: if any of the three files is missing.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # DataLoader raises ValueError if persistent_workers=True while
    # num_workers == 0, so only keep workers alive when they exist.
    keep_workers = num_workers > 0

    def _loader(dataset, shuffle):
        # All three loaders share every setting except dataset and shuffle.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=keep_workers,
        )

    train_loader = _loader(train_dataset, shuffle=True)
    valid_loader = _loader(valid_dataset, shuffle=False)
    test_loader = _loader(test_dataset, shuffle=False)

    return train_loader, valid_loader, test_loader
|
|
|