File size: 1,473 Bytes
117e99b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import torch
from transformers.utils import ModelOutput
from transformers import PreTrainedModel

from .configuration_seqscreen import SeqScreenConfig

@dataclass
class SeqScreenModelOutput(ModelOutput):
  """Container for the outputs of ``SeqScreenModel.forward``.

  All fields default to ``None``.

  NOTE(review): transformers' ``ModelOutput`` also supports tuple-style
  access, so the field order is presumably part of the public interface —
  keep it stable.
  """
  # Protein inputs after ``proj_prot``; unit L2 norm along the last dimension.
  prot_rep: torch.FloatTensor = None
  # Molecule inputs after ``proj_mol``; unit L2 norm along the last dimension.
  mol_rep: torch.FloatTensor = None
  # Matrix product of prot_rep with the transposed mol_rep. Because both sides
  # are unit-norm, entries are cosine similarities. For 2-D inputs this is
  # (num_prot, num_mol) — other ranks: TODO confirm against callers.
  similarity: torch.FloatTensor = None

class ProjectionLayer(nn.Module):
  """Two-layer MLP head whose outputs are L2-normalised.

  Pipeline: Linear(in_dim -> out_dim) -> LayerNorm -> GELU -> Dropout ->
  Linear(out_dim -> out_dim), then ``F.normalize`` over the last dimension,
  so each output vector has unit Euclidean norm (up to normalize's eps).

  Args:
    in_dim: size of the last dimension of the input tensor.
    out_dim: width of the hidden layer and of the output.
    dropout: dropout probability applied after the activation.
  """

  def __init__(self, in_dim, out_dim, dropout):
    super().__init__()
    # Stage order fixes both the Sequential indices (state_dict keys
    # ``projection.0/1/4.*``) and the order parameters are initialised in;
    # keep it stable so existing checkpoints keep loading.
    stages = [
      nn.Linear(in_dim, out_dim),
      nn.LayerNorm(out_dim),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(out_dim, out_dim),
    ]
    self.projection = nn.Sequential(*stages)

  def forward(self, x):
    """Return the projection of ``x`` scaled to unit length on the last dim."""
    hidden = self.projection(x)
    return F.normalize(hidden, p=2.0, dim=-1)


class SeqScreenModel(PreTrainedModel):
  """Protein/molecule dual-projection model built on ``PreTrainedModel``.

  Protein features (last dim ``config.prot_dim``) and molecule features
  (last dim ``config.mol_dim``) go through separate ``ProjectionLayer``
  heads into a shared ``config.proj_dim`` space, where every vector is
  L2-normalised. ``forward`` also returns the dot products between the two
  sets, which are cosine similarities because both sides are unit-norm.
  """
  config_class = SeqScreenConfig
  base_model_prefix = "seqscreen"

  def __init__(self, config: SeqScreenConfig):
    super().__init__(config)

    # One projection head per modality, both mapping into config.proj_dim.
    self.proj_prot = ProjectionLayer(config.prot_dim, config.proj_dim, dropout=config.dropout)
    self.proj_mol = ProjectionLayer(config.mol_dim, config.proj_dim, dropout=config.dropout)

    # Let PreTrainedModel finish construction now that all submodules exist.
    self.post_init()

  def forward(self, prot: torch.Tensor, mol: torch.Tensor) -> SeqScreenModelOutput:
    """Project both inputs and score every protein against every molecule.

    Args:
      prot: protein features whose last dimension is ``config.prot_dim``.
      mol: molecule features whose last dimension is ``config.mol_dim``.

    Returns:
      SeqScreenModelOutput with the normalised ``prot_rep`` and ``mol_rep``
      and ``similarity = prot_rep @ mol_rep^T`` taken over the last two dims.
      For 2-D inputs ``(P, prot_dim)`` / ``(M, mol_dim)`` the similarity is
      ``(P, M)``. Leading batch dimensions broadcast as in ``torch.matmul``.
    """
    prot_rep = self.proj_prot(prot)
    mol_rep = self.proj_mol(mol)

    # ``Tensor.T`` reverses *all* dims (deprecated for ndim != 2), which breaks
    # batched 3-D inputs; swap only the last two instead. A 1-D vector is used
    # as-is, matching the identity that ``.T`` applied to it before.
    if mol_rep.dim() >= 2:
      mol_t = mol_rep.transpose(-2, -1)
    else:
      mol_t = mol_rep
    similarity = prot_rep @ mol_t

    return SeqScreenModelOutput(
      prot_rep=prot_rep,
      mol_rep=mol_rep,
      similarity=similarity
    )