|
|
"""CodeT5 Vulnerability Detection model |
|
|
Binary Classication Safe(0) vs Vulnerable(1)""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import T5ForConditionalGeneration, RobertaTokenizer |
|
|
|
|
|
class VulnerabilityCodeT5(nn.Module):
    """CodeT5-based binary classifier for vulnerability detection.

    Wraps a pretrained ``T5ForConditionalGeneration`` and attaches a small
    MLP head on top of the encoder output. Labels: 0 = safe, 1 = vulnerable.

    NOTE(review): only the encoder is ever used in ``forward``; the decoder
    weights are loaded but unused. Loading ``T5EncoderModel`` instead would
    roughly halve the memory footprint — kept as-is here to preserve the
    ``encoder_decoder.*`` state-dict attribute names of existing checkpoints.
    """

    def __init__(self, model_name="Salesforce/codet5-base", num_labels=2):
        """
        Args:
            model_name: HuggingFace model id (or local path) of the CodeT5
                checkpoint to load.
            num_labels: number of output classes (2 for safe/vulnerable).
        """
        super().__init__()

        self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name)

        # Encoder hidden width drives the classifier head dimensions.
        hidden_size = self.encoder_decoder.config.d_model

        # Two-layer MLP classification head on top of the pooled encoding.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )

        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Forward pass.

        Args:
            input_ids: tokenized code, shape [batch_size, seq_len]
            attention_mask: attention mask (1 = real token, 0 = padding),
                shape [batch_size, seq_len]
            labels: optional ground-truth labels, shape [batch_size]

        Returns:
            dict with keys:
                'loss'          — CrossEntropyLoss scalar, or None when
                                  ``labels`` is None
                'logits'        — [batch_size, num_labels]
                'hidden_states' — encoder output [batch_size, seq_len, d_model]
        """
        encoder_outputs = self.encoder_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        hidden_state = encoder_outputs.last_hidden_state

        # FIX: T5/CodeT5 tokenizers do not prepend a [CLS] token, so the
        # original first-token pooling (hidden_state[:, 0, :]) just picked
        # an arbitrary leading code token. Use an attention-masked mean over
        # all valid positions instead — the standard pooling for T5-style
        # encoder classification.
        mask = attention_mask.unsqueeze(-1).to(hidden_state.dtype)
        summed = (hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-zero masks
        pooled_output = summed / counts

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': hidden_state
        }

    def predict(self, input_ids, attention_mask):
        """Predict labels and class probabilities for a batch.

        Args:
            input_ids: tokenized code, shape [batch_size, seq_len]
            attention_mask: attention mask, shape [batch_size, seq_len]

        Returns:
            (predictions [batch_size], probs [batch_size, num_labels])
        """
        # FIX: remember the current mode and restore it afterwards — the
        # original left the module permanently in eval mode, silently
        # disabling dropout for any subsequent training steps.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask)
                probs = torch.softmax(outputs["logits"], dim=1)
                predictions = torch.argmax(probs, dim=1)
        finally:
            if was_training:
                self.train()

        return predictions, probs
|
|
|
|
|
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total