"""CodeT5 Vulnerability Detection model Binary Classication Safe(0) vs Vulnerable(1)""" import torch import torch.nn as nn from transformers import T5ForConditionalGeneration, RobertaTokenizer class VulnerabilityCodeT5(nn.Module): """CodeT5 model for vulnerability detection""" def __init__(self, model_name="Salesforce/codet5-base", num_labels=2): super().__init__() self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name) #Get hidden size from config hidden_size = self.encoder_decoder.config.d_model #768 for base #Classification Head self.classifier = nn.Sequential( nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, num_labels) ) self.num_labels = num_labels def forward(self, input_ids, attention_mask, labels=None): """ Forward pass Args: input_ids : tokenized code [batch_size, seq_len] attention_mask : attention mask [batch_size, seq_len] labels: ground truth labels [batch_size] """ #Get encoder outputs encoder_outputs = self.encoder_decoder.encoder( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) #Pool encoder outputs (use first token [CLS]) hidden_state = encoder_outputs.last_hidden_state # [batch, seq_len, hidden] pooled_output = hidden_state[:, 0, :] # [batch, hidden] #Classification logits = self.classifier(pooled_output) # [batch, num_labels] #Calculate loss loss = None if labels is not None: loss_fn = nn.CrossEntropyLoss() loss = loss_fn(logits, labels) return { 'loss': loss, 'logits': logits, 'hidden_states': hidden_state } def predict(self, input_ids, attention_mask): """Make Predictions""" self.eval() with torch.no_grad(): outputs = self.forward(input_ids, attention_mask) probs = torch.softmax(outputs["logits"], dim=1) predictions = torch.argmax(probs, dim=1) return predictions, probs def count_parameters(model): """Count trainable parameters""" return sum(p.numel() for p in model.parameters() if p.requires_grad)