# Pranav Pc
# Final Deploy
# commit 4b82ab5
"""CodeT5 vulnerability detection model.
Binary classification: Safe (0) vs Vulnerable (1)."""
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, RobertaTokenizer
class VulnerabilityCodeT5(nn.Module):
    """CodeT5-based binary classifier for vulnerability detection.

    Wraps a pretrained T5 encoder-decoder and attaches a two-layer MLP
    classification head on top of the encoder's first-token
    representation. Only the encoder path is used in ``forward``.
    """

    def __init__(self, model_name="Salesforce/codet5-base", num_labels=2):
        super().__init__()
        # The full seq2seq model is loaded, but forward() only runs the encoder.
        self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name)
        # Hidden size from config (768 for codet5-base).
        hidden_size = self.encoder_decoder.config.d_model
        # Classification head: dropout -> linear -> ReLU -> dropout -> linear.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: tokenized code, shape [batch_size, seq_len]
            attention_mask: attention mask, shape [batch_size, seq_len]
            labels: optional ground-truth labels, shape [batch_size]

        Returns:
            dict with 'loss' (None when labels is None), 'logits'
            [batch, num_labels], and 'hidden_states'
            [batch, seq_len, hidden].
        """
        encoder_outputs = self.encoder_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_state = encoder_outputs.last_hidden_state  # [batch, seq_len, hidden]
        # Pool by taking the first token's representation. NOTE: T5 has no
        # dedicated [CLS] token, so this is simply the first input token.
        pooled_output = hidden_state[:, 0, :]  # [batch, hidden]
        logits = self.classifier(pooled_output)  # [batch, num_labels]
        # Loss is only computed when labels are supplied (training/eval).
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': hidden_state,
        }

    def predict(self, input_ids, attention_mask):
        """Return (predictions, class probabilities) without gradients.

        Fix: the original called ``self.eval()`` and never switched back,
        leaving the module permanently in eval mode; the previous
        train/eval state is now restored after prediction.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask)
                probs = torch.softmax(outputs["logits"], dim=1)
                predictions = torch.argmax(probs, dim=1)
        finally:
            # Restore the caller's mode (e.g. mid-training evaluation).
            if was_training:
                self.train()
        return predictions, probs
def count_parameters(model):
    """Return the total number of trainable parameters in *model*.

    Frozen parameters (``requires_grad == False``) are excluded.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total