"""SeqScreen model: project protein and molecule features into a shared space.

Each modality has its own :class:`ProjectionLayer`. The projected outputs are
L2-normalized, so their matrix product gives cosine similarities between the
protein and molecule representations.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_seqscreen import SeqScreenConfig


@dataclass
class SeqScreenModelOutput(ModelOutput):
    """Output container returned by :meth:`SeqScreenModel.forward`.

    Attributes:
        prot_rep: L2-normalized projected protein representations.
        mol_rep: L2-normalized projected molecule representations.
        similarity: Matrix product of ``prot_rep`` with the (last-two-dims)
            transpose of ``mol_rep``; because both are unit-normalized these
            are cosine similarities.
    """

    prot_rep: Optional[torch.FloatTensor] = None
    mol_rep: Optional[torch.FloatTensor] = None
    similarity: Optional[torch.FloatTensor] = None


class ProjectionLayer(nn.Module):
    """MLP head mapping ``in_dim`` features to L2-normalized ``out_dim`` vectors.

    Layout: Linear -> LayerNorm -> GELU -> Dropout -> Linear, followed by
    L2 normalization over the last dimension.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the projected output.
        dropout: Dropout probability applied after the GELU activation.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float):
        super().__init__()
        # NOTE: attribute name and layer order define the state_dict keys;
        # keep them stable so existing checkpoints still load.
        self.projection = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and L2-normalize the result along the last dimension."""
        x = self.projection(x)
        # Unit-normalizing makes downstream dot products equal cosine similarity.
        return F.normalize(x, dim=-1)


class SeqScreenModel(PreTrainedModel):
    """Two-tower head scoring proteins against molecules by cosine similarity.

    Holds one projection head per modality. Both map into
    ``config.proj_dim``, from ``config.prot_dim`` and ``config.mol_dim``
    input features respectively.
    """

    config_class = SeqScreenConfig
    base_model_prefix = "seqscreen"

    def __init__(self, config: SeqScreenConfig):
        super().__init__(config)
        self.proj_prot = ProjectionLayer(config.prot_dim, config.proj_dim, dropout=config.dropout)
        self.proj_mol = ProjectionLayer(config.mol_dim, config.proj_dim, dropout=config.dropout)
        self.post_init()

    def forward(self, prot: torch.Tensor, mol: torch.Tensor) -> SeqScreenModelOutput:
        """Project both inputs and score every protein against every molecule.

        Args:
            prot: Protein features whose last dimension is ``config.prot_dim``,
                e.g. ``(N, prot_dim)``.
            mol: Molecule features whose last dimension is ``config.mol_dim``,
                e.g. ``(M, mol_dim)``.

        Returns:
            :class:`SeqScreenModelOutput`. ``similarity`` has shape ``(N, M)``
            for 2-D inputs and ``(B, N, M)`` for batched ``(B, N, *)`` /
            ``(B, M, *)`` inputs.
        """
        prot_rep = self.proj_prot(prot)
        mol_rep = self.proj_mol(mol)
        # Swap only the last two dims. ``Tensor.T`` reverses *all* dims, which
        # is deprecated for ndim != 2 and scrambles batched inputs. For 2-D
        # tensors this is identical to ``.T``; 1-D tensors are left unchanged,
        # just as ``.T`` left them.
        mol_t = mol_rep.transpose(-2, -1) if mol_rep.dim() >= 2 else mol_rep
        similarity = prot_rep @ mol_t
        return SeqScreenModelOutput(
            prot_rep=prot_rep,
            mol_rep=mol_rep,
            similarity=similarity,
        )