|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import csv, re, json |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Tuple, Any, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import joblib |
|
|
import xgboost as xgb |
|
|
|
|
|
from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM |
|
|
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True) |
|
|
class BestRow: |
|
|
property_key: str |
|
|
best_wt: Optional[str] |
|
|
best_smiles: Optional[str] |
|
|
task_type: str |
|
|
thr_wt: Optional[float] |
|
|
thr_smiles: Optional[float] |
|
|
|
|
|
|
|
|
def _clean(s: str) -> str: |
|
|
return (s or "").strip() |
|
|
|
|
|
def _none_if_dash(s: str) -> Optional[str]: |
|
|
s = _clean(s) |
|
|
if s in {"", "-", "—", "NA", "N/A"}: |
|
|
return None |
|
|
return s |
|
|
|
|
|
def _float_or_none(s: str) -> Optional[float]: |
|
|
s = _clean(s) |
|
|
if s in {"", "-", "—", "NA", "N/A"}: |
|
|
return None |
|
|
return float(s) |
|
|
|
|
|
def normalize_property_key(name: str) -> str: |
|
|
n = name.strip().lower() |
|
|
n = re.sub(r"\s*\(.*?\)\s*", "", n) |
|
|
n = n.replace("-", "_").replace(" ", "_") |
|
|
if "permeability" in n and "pampa" not in n and "caco" not in n: |
|
|
return "permeability_penetrance" |
|
|
if n == "binding_affinity": |
|
|
return "binding_affinity" |
|
|
if n == "halflife": |
|
|
return "half_life" |
|
|
if n == "non_fouling": |
|
|
return "nf" |
|
|
return n |
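
# Expected behavior of normalize_property_key on a few illustrative names (the
# inputs here are examples, not necessarily rows from the real manifest):
#   "Hemolysis"           -> "hemolysis"
#   "Half-Life"           -> "half_life"
#   "Non Fouling"         -> "nf"
#   "Permeability (cell)" -> "permeability_penetrance"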
|
|
|
|
|
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]: |
|
|
""" |
|
|
Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES, |
|
|
Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223, |
|
|
""" |
|
|
p = Path(path) |
|
|
out: Dict[str, BestRow] = {} |
|
|
|
|
|
with p.open("r", newline="") as f: |
|
|
reader = csv.reader(f) |
|
|
header = None |
|
|
for raw in reader: |
|
|
if not raw or all(_clean(x) == "" for x in raw): |
|
|
continue |
|
|
while raw and _clean(raw[-1]) == "": |
|
|
raw = raw[:-1] |
|
|
|
|
|
if header is None: |
|
|
header = [h.strip() for h in raw] |
|
|
continue |
|
|
|
|
|
if len(raw) < len(header): |
|
|
raw = raw + [""] * (len(header) - len(raw)) |
|
|
rec = dict(zip(header, raw)) |
|
|
|
|
|
prop_raw = _clean(rec.get("Properties", "")) |
|
|
if not prop_raw: |
|
|
continue |
|
|
prop_key = normalize_property_key(prop_raw) |
|
|
|
|
|
row = BestRow( |
|
|
property_key=prop_key, |
|
|
best_wt=_none_if_dash(rec.get("Best_Model_WT", "")), |
|
|
best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")), |
|
|
task_type=_clean(rec.get("Type", "Classifier")), |
|
|
thr_wt=_float_or_none(rec.get("Threshold_WT", "")), |
|
|
thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")), |
|
|
) |
|
|
out[prop_key] = row |
|
|
|
|
|
return out |
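
# Minimal usage sketch for the manifest reader; the filename mirrors the one used
# in the __main__ demo below and may differ in your setup:
#
#   manifest = read_best_manifest_csv("best_models.txt")
#   row = manifest["hemolysis"]
#   print(row.best_wt, row.best_smiles, row.task_type, row.thr_wt, row.thr_smiles)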
|
|
|
|
|
|
|
|
MODEL_ALIAS = { |
|
|
"SVM": "svm_gpu", |
|
|
"SVR": "svr", |
|
|
"ENET": "enet_gpu", |
|
|
"CNN": "cnn", |
|
|
"MLP": "mlp", |
|
|
"TRANSFORMER": "transformer", |
|
|
"XGB": "xgb", |
|
|
"XGB_REG": "xgb_reg", |
|
|
"POOLED": "pooled", |
|
|
"UNPOOLED": "unpooled" |
|
|
} |
|
|
def canon_model(label: Optional[str]) -> Optional[str]: |
|
|
if label is None: |
|
|
return None |
|
|
k = label.strip().upper() |
|
|
return MODEL_ALIAS.get(k, label.strip().lower()) |
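
# canon_model maps manifest labels onto the folder-name stems used on disk, e.g.:
#   canon_model("SVM")    -> "svm_gpu"
#   canon_model("Pooled") -> "pooled"
#   canon_model(None)     -> None
# Labels missing from MODEL_ALIAS are simply lower-cased and passed through.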
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_best_artifact(model_dir: Path) -> Path: |
|
|
for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]: |
|
|
hits = sorted(model_dir.glob(pat)) |
|
|
if hits: |
|
|
return hits[0] |
|
|
raise FileNotFoundError(f"No best_model artifact found in {model_dir}") |
|
|
|
|
|
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]: |
|
|
art = find_best_artifact(model_dir) |
|
|
|
|
|
if art.suffix == ".json": |
|
|
        booster = xgb.Booster()
        booster.load_model(str(art))
|
|
return "xgb", booster, art |
|
|
|
|
|
if art.suffix == ".joblib": |
|
|
obj = joblib.load(art) |
|
|
return "joblib", obj, art |
|
|
|
|
|
if art.suffix == ".pt": |
|
|
ckpt = torch.load(art, map_location=device, weights_only=False) |
|
|
return "torch_ckpt", ckpt, art |
|
|
|
|
|
raise ValueError(f"Unknown artifact type: {art}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskedMeanPool(nn.Module):
    """Mean over the token dimension, counting only positions where the mask is True."""

    def forward(self, X, M):
        # X: (B, L, H) token embeddings; M: (B, L) boolean validity mask
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom
|
|
|
|
|
class MLPHead(nn.Module): |
|
|
def __init__(self, in_dim, hidden=512, dropout=0.1): |
|
|
super().__init__() |
|
|
self.pool = MaskedMeanPool() |
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(in_dim, hidden), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden, 1), |
|
|
) |
|
|
def forward(self, X, M): |
|
|
z = self.pool(X, M) |
|
|
return self.net(z).squeeze(-1) |
|
|
|
|
|
class CNNHead(nn.Module): |
|
|
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1): |
|
|
super().__init__() |
|
|
blocks = [] |
|
|
ch = in_ch |
|
|
for _ in range(layers): |
|
|
blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout)] |
|
|
ch = c |
|
|
self.conv = nn.Sequential(*blocks) |
|
|
self.head = nn.Linear(c, 1) |
|
|
|
|
|
def forward(self, X, M): |
|
|
Xc = X.transpose(1, 2) |
|
|
Y = self.conv(Xc).transpose(1, 2) |
|
|
Mf = M.unsqueeze(-1).float() |
|
|
denom = Mf.sum(dim=1).clamp(min=1.0) |
|
|
pooled = (Y * Mf).sum(dim=1) / denom |
|
|
return self.head(pooled).squeeze(-1) |
|
|
|
|
|
class TransformerHead(nn.Module): |
|
|
def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1): |
|
|
super().__init__() |
|
|
self.proj = nn.Linear(in_dim, d_model) |
|
|
enc_layer = nn.TransformerEncoderLayer( |
|
|
d_model=d_model, nhead=nhead, dim_feedforward=ff, |
|
|
dropout=dropout, batch_first=True, activation="gelu" |
|
|
) |
|
|
self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers) |
|
|
self.head = nn.Linear(d_model, 1) |
|
|
|
|
|
def forward(self, X, M): |
|
|
pad_mask = ~M |
|
|
Z = self.proj(X) |
|
|
Z = self.enc(Z, src_key_padding_mask=pad_mask) |
|
|
Mf = M.unsqueeze(-1).float() |
|
|
denom = Mf.sum(dim=1).clamp(min=1.0) |
|
|
pooled = (Z * Mf).sum(dim=1) / denom |
|
|
return self.head(pooled).squeeze(-1) |
|
|
|
|
|
def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int: |
|
|
if model_name == "mlp": |
|
|
return int(sd["net.0.weight"].shape[1]) |
|
|
if model_name == "cnn": |
|
|
return int(sd["conv.0.weight"].shape[1]) |
|
|
if model_name == "transformer": |
|
|
return int(sd["proj.weight"].shape[1]) |
|
|
raise ValueError(model_name) |
|
|
|
|
|
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module: |
|
|
params = ckpt["best_params"] |
|
|
sd = ckpt["state_dict"] |
|
|
in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name))) |
|
|
dropout = float(params.get("dropout", 0.1)) |
|
|
|
|
|
if model_name == "mlp": |
|
|
model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout) |
|
|
elif model_name == "cnn": |
|
|
model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]), |
|
|
layers=int(params["layers"]), dropout=dropout) |
|
|
elif model_name == "transformer": |
|
|
model = TransformerHead(in_dim=in_dim, d_model=int(params["d_model"]), nhead=int(params["nhead"]), |
|
|
layers=int(params["layers"]), ff=int(params["ff"]), dropout=dropout) |
|
|
else: |
|
|
raise ValueError(f"Unknown NN model_name={model_name}") |
|
|
|
|
|
model.load_state_dict(sd) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
return model |
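
# Sketch of the checkpoint contract assumed by build_torch_model_from_ckpt, plus a
# shape-only smoke test with random embeddings (hyperparameter values illustrative):
#
#   ckpt = {
#       "best_params": {"hidden": 512, "dropout": 0.1},  # keys depend on model_name
#       "state_dict": trained_state_dict,
#       "in_dim": 1280,                                   # optional; inferred if absent
#   }
#   model = build_torch_model_from_ckpt("mlp", ckpt, torch.device("cpu"))
#   X = torch.randn(1, 20, 1280)                # (batch, tokens, embed_dim)
#   M = torch.ones(1, 20, dtype=torch.bool)     # validity mask
#   with torch.no_grad():
#       score = model(X, M)                     # tensor of shape (1,)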
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def affinity_to_class(y: float) -> int:
    """Bucket a predicted binding-affinity value: 0 = high (>= 9), 1 = moderate (7-9), 2 = low (< 7)."""
    if y >= 9.0:
        return 0
    if y < 7.0:
        return 2
    return 1
|
|
|
|
|
class CrossAttnPooled(nn.Module): |
|
|
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): |
|
|
super().__init__() |
|
|
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) |
|
|
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) |
|
|
|
|
|
self.layers = nn.ModuleList([]) |
|
|
for _ in range(n_layers): |
|
|
self.layers.append(nn.ModuleDict({ |
|
|
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), |
|
|
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), |
|
|
"n1t": nn.LayerNorm(hidden), |
|
|
"n2t": nn.LayerNorm(hidden), |
|
|
"n1b": nn.LayerNorm(hidden), |
|
|
"n2b": nn.LayerNorm(hidden), |
|
|
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
|
|
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
|
|
})) |
|
|
|
|
|
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) |
|
|
self.reg = nn.Linear(hidden, 1) |
|
|
self.cls = nn.Linear(hidden, 3) |
|
|
|
|
|
def forward(self, t_vec, b_vec): |
|
|
t = self.t_proj(t_vec).unsqueeze(0) |
|
|
b = self.b_proj(b_vec).unsqueeze(0) |
|
|
for L in self.layers: |
|
|
t_attn, _ = L["attn_tb"](t, b, b) |
|
|
t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1) |
|
|
t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1) |
|
|
|
|
|
b_attn, _ = L["attn_bt"](b, t, t) |
|
|
b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1) |
|
|
b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1) |
|
|
|
|
|
z = torch.cat([t[0], b[0]], dim=-1) |
|
|
h = self.shared(z) |
|
|
return self.reg(h).squeeze(-1), self.cls(h) |
|
|
|
|
|
class CrossAttnUnpooled(nn.Module): |
|
|
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): |
|
|
super().__init__() |
|
|
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) |
|
|
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) |
|
|
|
|
|
self.layers = nn.ModuleList([]) |
|
|
for _ in range(n_layers): |
|
|
self.layers.append(nn.ModuleDict({ |
|
|
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), |
|
|
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), |
|
|
"n1t": nn.LayerNorm(hidden), |
|
|
"n2t": nn.LayerNorm(hidden), |
|
|
"n1b": nn.LayerNorm(hidden), |
|
|
"n2b": nn.LayerNorm(hidden), |
|
|
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
|
|
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
|
|
})) |
|
|
|
|
|
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) |
|
|
self.reg = nn.Linear(hidden, 1) |
|
|
self.cls = nn.Linear(hidden, 3) |
|
|
|
|
|
def _masked_mean(self, X, M): |
|
|
Mf = M.unsqueeze(-1).float() |
|
|
denom = Mf.sum(dim=1).clamp(min=1.0) |
|
|
return (X * Mf).sum(dim=1) / denom |
|
|
|
|
|
def forward(self, T, Mt, B, Mb): |
|
|
T = self.t_proj(T) |
|
|
Bx = self.b_proj(B) |
|
|
kp_t = ~Mt |
|
|
kp_b = ~Mb |
|
|
|
|
|
for L in self.layers: |
|
|
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b) |
|
|
T = L["n1t"](T + T_attn) |
|
|
T = L["n2t"](T + L["fft"](T)) |
|
|
|
|
|
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t) |
|
|
Bx = L["n1b"](Bx + B_attn) |
|
|
Bx = L["n2b"](Bx + L["ffb"](Bx)) |
|
|
|
|
|
t_pool = self._masked_mean(T, Mt) |
|
|
b_pool = self._masked_mean(Bx, Mb) |
|
|
z = torch.cat([t_pool, b_pool], dim=-1) |
|
|
h = self.shared(z) |
|
|
return self.reg(h).squeeze(-1), self.cls(h) |
|
|
|
|
|
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module: |
|
|
ckpt = torch.load(best_model_pt, map_location=device, weights_only=False) |
|
|
params = ckpt["best_params"] |
|
|
sd = ckpt["state_dict"] |
|
|
|
|
|
|
|
|
Ht = int(sd["t_proj.0.weight"].shape[1]) |
|
|
Hb = int(sd["b_proj.0.weight"].shape[1]) |
|
|
|
|
|
common = dict( |
|
|
Ht=Ht, Hb=Hb, |
|
|
hidden=int(params["hidden_dim"]), |
|
|
n_heads=int(params["n_heads"]), |
|
|
n_layers=int(params["n_layers"]), |
|
|
dropout=float(params["dropout"]), |
|
|
) |
|
|
|
|
|
if pooled_or_unpooled == "pooled": |
|
|
model = CrossAttnPooled(**common) |
|
|
elif pooled_or_unpooled == "unpooled": |
|
|
model = CrossAttnUnpooled(**common) |
|
|
else: |
|
|
raise ValueError(pooled_or_unpooled) |
|
|
|
|
|
model.load_state_dict(sd) |
|
|
model.to(device).eval() |
|
|
return model |
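
# load_binding_model expects best_model.pt to contain at least:
#   "best_params": {"hidden_dim", "n_heads", "n_layers", "dropout"}
#   "state_dict":  weights for CrossAttnPooled or CrossAttnUnpooled
# The target/binder embedding widths (Ht/Hb) are recovered from t_proj/b_proj in
# the state_dict, so they do not need to be stored in the checkpoint.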
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Pytorch patch |
|
|
""" |
|
|
if hasattr(torch, "isin"): |
|
|
return torch.isin(ids, test_ids) |
|
|
|
|
|
|
|
|
return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1) |
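
# Shape walk-through of the fallback above, for ids (B, L) and test_ids (K,):
#   ids.unsqueeze(-1)          -> (B, L, 1)
#   test_ids.view(1, 1, -1)    -> (1, 1, K)
#   broadcast == then .any(-1) -> (B, L) bool, True where ids[b, l] is in test_ids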
|
|
|
|
|
class SMILESEmbedder: |
|
|
""" |
|
|
PeptideCLM RoFormer embeddings for SMILES. |
|
|
- pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS |
|
|
- unpooled(): returns token embeddings filtered to valid tokens (specials removed), |
|
|
      plus an all-ones mask of length Li (padding and special tokens already removed).
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
device: torch.device, |
|
|
vocab_path: str, |
|
|
splits_path: str, |
|
|
clm_name: str = "aaronfeller/PeptideCLM-23M-all", |
|
|
max_len: int = 512, |
|
|
use_cache: bool = True, |
|
|
): |
|
|
self.device = device |
|
|
self.max_len = max_len |
|
|
self.use_cache = use_cache |
|
|
|
|
|
self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path) |
|
|
self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval() |
|
|
|
|
|
self.special_ids = self._get_special_ids(self.tokenizer) |
|
|
self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) |
|
|
if len(self.special_ids) else None) |
|
|
|
|
|
self._cache_pooled: Dict[str, torch.Tensor] = {} |
|
|
self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} |
|
|
|
|
|
@staticmethod |
|
|
def _get_special_ids(tokenizer) -> List[int]: |
|
|
cand = [ |
|
|
getattr(tokenizer, "pad_token_id", None), |
|
|
getattr(tokenizer, "cls_token_id", None), |
|
|
getattr(tokenizer, "sep_token_id", None), |
|
|
getattr(tokenizer, "bos_token_id", None), |
|
|
getattr(tokenizer, "eos_token_id", None), |
|
|
getattr(tokenizer, "mask_token_id", None), |
|
|
] |
|
|
return sorted({int(x) for x in cand if x is not None}) |
|
|
|
|
|
def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]: |
|
|
tok = self.tokenizer( |
|
|
smiles_list, |
|
|
return_tensors="pt", |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=self.max_len, |
|
|
) |
|
|
for k in tok: |
|
|
tok[k] = tok[k].to(self.device) |
|
|
if "attention_mask" not in tok: |
|
|
tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) |
|
|
return tok |
|
|
|
|
|
@torch.no_grad() |
|
|
def pooled(self, smiles: str) -> torch.Tensor: |
|
|
s = smiles.strip() |
|
|
if self.use_cache and s in self._cache_pooled: |
|
|
return self._cache_pooled[s] |
|
|
|
|
|
tok = self._tokenize([s]) |
|
|
ids = tok["input_ids"] |
|
|
attn = tok["attention_mask"].bool() |
|
|
|
|
|
out = self.model(input_ids=ids, attention_mask=tok["attention_mask"]) |
|
|
h = out.last_hidden_state |
|
|
|
|
|
valid = attn |
|
|
if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
|
|
valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
|
|
|
vf = valid.unsqueeze(-1).float() |
|
|
summed = (h * vf).sum(dim=1) |
|
|
denom = vf.sum(dim=1).clamp(min=1e-9) |
|
|
pooled = summed / denom |
|
|
|
|
|
if self.use_cache: |
|
|
self._cache_pooled[s] = pooled |
|
|
return pooled |
|
|
|
|
|
@torch.no_grad() |
|
|
def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Returns: |
|
|
X: (1, Li, H) float32 on device |
|
|
M: (1, Li) bool on device |
|
|
where Li excludes padding + special tokens. |
|
|
""" |
|
|
s = smiles.strip() |
|
|
if self.use_cache and s in self._cache_unpooled: |
|
|
return self._cache_unpooled[s] |
|
|
|
|
|
tok = self._tokenize([s]) |
|
|
ids = tok["input_ids"] |
|
|
attn = tok["attention_mask"].bool() |
|
|
|
|
|
out = self.model(input_ids=ids, attention_mask=tok["attention_mask"]) |
|
|
h = out.last_hidden_state |
|
|
|
|
|
valid = attn |
|
|
if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
|
|
valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
|
|
|
|
|
|
keep = valid[0] |
|
|
X = h[:, keep, :] |
|
|
M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) |
|
|
|
|
|
if self.use_cache: |
|
|
self._cache_unpooled[s] = (X, M) |
|
|
return X, M |
|
|
|
|
|
|
|
|
class WTEmbedder: |
|
|
""" |
|
|
ESM2 embeddings for AA sequences. |
|
|
- pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...} |
|
|
- unpooled(): returns token embeddings filtered to valid tokens (specials removed), |
|
|
      plus an all-ones mask of length Li (padding and special tokens already removed).
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
device: torch.device, |
|
|
esm_name: str = "facebook/esm2_t33_650M_UR50D", |
|
|
max_len: int = 1022, |
|
|
use_cache: bool = True, |
|
|
): |
|
|
self.device = device |
|
|
self.max_len = max_len |
|
|
self.use_cache = use_cache |
|
|
|
|
|
self.tokenizer = EsmTokenizer.from_pretrained(esm_name) |
|
|
self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval() |
|
|
|
|
|
self.special_ids = self._get_special_ids(self.tokenizer) |
|
|
self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) |
|
|
if len(self.special_ids) else None) |
|
|
|
|
|
self._cache_pooled: Dict[str, torch.Tensor] = {} |
|
|
self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} |
|
|
|
|
|
@staticmethod |
|
|
def _get_special_ids(tokenizer) -> List[int]: |
|
|
cand = [ |
|
|
getattr(tokenizer, "pad_token_id", None), |
|
|
getattr(tokenizer, "cls_token_id", None), |
|
|
getattr(tokenizer, "sep_token_id", None), |
|
|
getattr(tokenizer, "bos_token_id", None), |
|
|
getattr(tokenizer, "eos_token_id", None), |
|
|
getattr(tokenizer, "mask_token_id", None), |
|
|
] |
|
|
return sorted({int(x) for x in cand if x is not None}) |
|
|
|
|
|
def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]: |
|
|
tok = self.tokenizer( |
|
|
seq_list, |
|
|
return_tensors="pt", |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=self.max_len, |
|
|
) |
|
|
tok = {k: v.to(self.device) for k, v in tok.items()} |
|
|
if "attention_mask" not in tok: |
|
|
tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) |
|
|
return tok |
|
|
|
|
|
@torch.no_grad() |
|
|
def pooled(self, seq: str) -> torch.Tensor: |
|
|
s = seq.strip() |
|
|
if self.use_cache and s in self._cache_pooled: |
|
|
return self._cache_pooled[s] |
|
|
|
|
|
tok = self._tokenize([s]) |
|
|
ids = tok["input_ids"] |
|
|
attn = tok["attention_mask"].bool() |
|
|
|
|
|
out = self.model(**tok) |
|
|
h = out.last_hidden_state |
|
|
|
|
|
valid = attn |
|
|
if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
|
|
valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
|
|
|
vf = valid.unsqueeze(-1).float() |
|
|
summed = (h * vf).sum(dim=1) |
|
|
denom = vf.sum(dim=1).clamp(min=1e-9) |
|
|
pooled = summed / denom |
|
|
|
|
|
if self.use_cache: |
|
|
self._cache_pooled[s] = pooled |
|
|
return pooled |
|
|
|
|
|
@torch.no_grad() |
|
|
def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Returns: |
|
|
X: (1, Li, H) float32 on device |
|
|
M: (1, Li) bool on device |
|
|
where Li excludes padding + special tokens. |
|
|
""" |
|
|
s = seq.strip() |
|
|
if self.use_cache and s in self._cache_unpooled: |
|
|
return self._cache_unpooled[s] |
|
|
|
|
|
tok = self._tokenize([s]) |
|
|
ids = tok["input_ids"] |
|
|
attn = tok["attention_mask"].bool() |
|
|
|
|
|
out = self.model(**tok) |
|
|
h = out.last_hidden_state |
|
|
|
|
|
valid = attn |
|
|
if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
|
|
valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
|
|
|
keep = valid[0] |
|
|
X = h[:, keep, :] |
|
|
M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) |
|
|
|
|
|
if self.use_cache: |
|
|
self._cache_unpooled[s] = (X, M) |
|
|
return X, M |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PeptiVersePredictor: |
|
|
""" |
|
|
- loads best models from training_classifiers/ |
|
|
- computes embeddings as needed (pooled/unpooled) |
|
|
- supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled. |
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
manifest_path: str | Path, |
|
|
classifier_weight_root: str | Path, |
|
|
esm_name="facebook/esm2_t33_650M_UR50D", |
|
|
clm_name="aaronfeller/PeptideCLM-23M-all", |
|
|
smiles_vocab="tokenizer/new_vocab.txt", |
|
|
smiles_splits="tokenizer/new_splits.txt", |
|
|
device: Optional[str] = None, |
|
|
): |
|
|
self.root = Path(classifier_weight_root) |
|
|
self.training_root = self.root / "training_classifiers" |
|
|
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) |
|
|
|
|
|
self.manifest = read_best_manifest_csv(manifest_path) |
|
|
|
|
|
self.wt_embedder = WTEmbedder(self.device) |
|
|
self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name, |
|
|
vocab_path=str(self.root / smiles_vocab), |
|
|
splits_path=str(self.root / smiles_splits)) |
|
|
|
|
|
self.models: Dict[Tuple[str, str], Any] = {} |
|
|
self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {} |
|
|
|
|
|
self._load_all_best_models() |
|
|
|
|
|
def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path: |
|
|
""" |
|
|
Usual layout: training_classifiers/<prop>/<model>_<mode>/ |
|
|
Fallbacks: |
|
|
- training_classifiers/<prop>/<model>/ |
|
|
- training_classifiers/<prop>/<model>_wt |
|
|
""" |
|
|
base = self.training_root / prop_key |
|
|
candidates = [ |
|
|
base / f"{model_name}_{mode}", |
|
|
base / model_name, |
|
|
] |
|
|
if mode == "wt": |
|
|
candidates += [base / f"{model_name}_wt"] |
|
|
if mode == "smiles": |
|
|
candidates += [base / f"{model_name}_smiles"] |
|
|
|
|
|
for d in candidates: |
|
|
if d.exists(): |
|
|
return d |
|
|
raise FileNotFoundError(f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}") |
|
|
|
|
|
def _load_all_best_models(self): |
|
|
for prop_key, row in self.manifest.items(): |
|
|
for mode, label, thr in [ |
|
|
("wt", row.best_wt, row.thr_wt), |
|
|
("smiles", row.best_smiles, row.thr_smiles), |
|
|
]: |
|
|
m = canon_model(label) |
|
|
if m is None: |
|
|
continue |
|
|
|
|
|
|
|
|
if prop_key == "binding_affinity": |
|
|
|
|
|
pooled_or_unpooled = m |
|
|
folder = f"wt_{mode}_{pooled_or_unpooled}" |
|
|
model_dir = self.training_root / "binding_affinity" / folder |
|
|
art = find_best_artifact(model_dir) |
|
|
if art.suffix != ".pt": |
|
|
raise RuntimeError(f"Binding model expected best_model.pt, got {art}") |
|
|
model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device) |
|
|
self.models[(prop_key, mode)] = model |
|
|
self.meta[(prop_key, mode)] = { |
|
|
"task_type": "Regression", |
|
|
"threshold": None, |
|
|
"artifact": str(art), |
|
|
"model_name": pooled_or_unpooled, |
|
|
} |
|
|
continue |
|
|
|
|
|
model_dir = self._resolve_dir(prop_key, m, mode) |
|
|
kind, obj, art = load_artifact(model_dir, self.device) |
|
|
|
|
|
if kind in {"xgb", "joblib"}: |
|
|
self.models[(prop_key, mode)] = obj |
|
|
else: |
|
|
|
|
|
self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device) |
|
|
|
|
|
self.meta[(prop_key, mode)] = { |
|
|
"task_type": row.task_type, |
|
|
"threshold": thr, |
|
|
"artifact": str(art), |
|
|
"model_name": m, |
|
|
"kind": kind, |
|
|
} |
|
|
|
|
|
def _get_features_for_model(self, prop_key: str, mode: str, input_str: str): |
|
|
""" |
|
|
Returns either: |
|
|
- pooled np array shape (1,H) for xgb/joblib |
|
|
- unpooled torch tensors (X,M) for NN |
|
|
""" |
|
|
model = self.models[(prop_key, mode)] |
|
|
meta = self.meta[(prop_key, mode)] |
|
|
kind = meta.get("kind", None) |
|
|
model_name = meta.get("model_name", "") |
|
|
|
|
|
if prop_key == "binding_affinity": |
|
|
raise RuntimeError("Use predict_binding_affinity().") |
|
|
|
|
|
|
|
|
if kind == "torch_ckpt": |
|
|
if mode == "wt": |
|
|
X, M = self.wt_embedder.unpooled(input_str) |
|
|
else: |
|
|
X, M = self.smiles_embedder.unpooled(input_str) |
|
|
return X, M |
|
|
|
|
|
|
|
|
if mode == "wt": |
|
|
v = self.wt_embedder.pooled(input_str) |
|
|
else: |
|
|
v = self.smiles_embedder.pooled(input_str) |
|
|
feats = v.detach().cpu().numpy().astype(np.float32) |
|
|
feats = np.nan_to_num(feats, nan=0.0) |
|
|
feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max) |
|
|
return feats |
|
|
|
|
|
def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]: |
|
|
""" |
|
|
mode: "wt" for AA sequence input, "smiles" for SMILES input |
|
|
Returns dict with score + label if classifier threshold exists. |
|
|
""" |
|
|
if (prop_key, mode) not in self.models: |
|
|
raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.") |
|
|
|
|
|
meta = self.meta[(prop_key, mode)] |
|
|
model = self.models[(prop_key, mode)] |
|
|
task_type = meta["task_type"].lower() |
|
|
thr = meta.get("threshold", None) |
|
|
kind = meta.get("kind", None) |
|
|
|
|
|
if prop_key == "binding_affinity": |
|
|
raise RuntimeError("Use predict_binding_affinity().") |
|
|
|
|
|
|
|
|
if kind == "torch_ckpt": |
|
|
X, M = self._get_features_for_model(prop_key, mode, input_str) |
|
|
with torch.no_grad(): |
|
|
y = model(X, M).squeeze().float().cpu().item() |
|
|
if task_type == "classifier": |
|
|
prob = float(1.0 / (1.0 + np.exp(-y))) |
|
|
out = {"property": prop_key, "mode": mode, "score": prob} |
|
|
if thr is not None: |
|
|
out["label"] = int(prob >= float(thr)) |
|
|
out["threshold"] = float(thr) |
|
|
return out |
|
|
else: |
|
|
return {"property": prop_key, "mode": mode, "score": float(y)} |
|
|
|
|
|
|
|
|
if kind == "xgb": |
|
|
feats = self._get_features_for_model(prop_key, mode, input_str) |
|
|
dmat = xgb.DMatrix(feats) |
|
|
pred = float(model.predict(dmat)[0]) |
|
|
out = {"property": prop_key, "mode": mode, "score": pred} |
|
|
if task_type == "classifier" and thr is not None: |
|
|
out["label"] = int(pred >= float(thr)) |
|
|
out["threshold"] = float(thr) |
|
|
return out |
|
|
|
|
|
|
|
|
if kind == "joblib": |
|
|
feats = self._get_features_for_model(prop_key, mode, input_str) |
|
|
|
|
|
if task_type == "classifier": |
|
|
if hasattr(model, "predict_proba"): |
|
|
pred = float(model.predict_proba(feats)[:, 1][0]) |
|
|
else: |
|
|
if hasattr(model, "decision_function"): |
|
|
logit = float(model.decision_function(feats)[0]) |
|
|
pred = float(1.0 / (1.0 + np.exp(-logit))) |
|
|
else: |
|
|
pred = float(model.predict(feats)[0]) |
|
|
out = {"property": prop_key, "mode": mode, "score": pred} |
|
|
if thr is not None: |
|
|
out["label"] = int(pred >= float(thr)) |
|
|
out["threshold"] = float(thr) |
|
|
return out |
|
|
else: |
|
|
pred = float(model.predict(feats)[0]) |
|
|
return {"property": prop_key, "mode": mode, "score": pred} |
|
|
|
|
|
raise RuntimeError(f"Unknown model kind={kind}") |
|
|
|
|
|
def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]: |
|
|
""" |
|
|
mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled) |
|
|
"smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled) |
|
|
""" |
|
|
prop_key = "binding_affinity" |
|
|
if (prop_key, mode) not in self.models: |
|
|
raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).") |
|
|
|
|
|
model = self.models[(prop_key, mode)] |
|
|
pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] |
|
|
|
|
|
|
|
|
if pooled_or_unpooled == "pooled": |
|
|
t_vec = self.wt_embedder.pooled(target_seq) |
|
|
if mode == "wt": |
|
|
b_vec = self.wt_embedder.pooled(binder_str) |
|
|
else: |
|
|
b_vec = self.smiles_embedder.pooled(binder_str) |
|
|
with torch.no_grad(): |
|
|
reg, logits = model(t_vec, b_vec) |
|
|
affinity = float(reg.squeeze().cpu().item()) |
|
|
cls_logit = int(torch.argmax(logits, dim=-1).cpu().item()) |
|
|
cls_thr = affinity_to_class(affinity) |
|
|
else: |
|
|
T, Mt = self.wt_embedder.unpooled(target_seq) |
|
|
if mode == "wt": |
|
|
B, Mb = self.wt_embedder.unpooled(binder_str) |
|
|
else: |
|
|
B, Mb = self.smiles_embedder.unpooled(binder_str) |
|
|
with torch.no_grad(): |
|
|
reg, logits = model(T, Mt, B, Mb) |
|
|
affinity = float(reg.squeeze().cpu().item()) |
|
|
cls_logit = int(torch.argmax(logits, dim=-1).cpu().item()) |
|
|
cls_thr = affinity_to_class(affinity) |
|
|
|
|
|
names = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"} |
|
|
return { |
|
|
"property": "binding_affinity", |
|
|
"mode": mode, |
|
|
"affinity": affinity, |
|
|
"class_by_threshold": names[cls_thr], |
|
|
"class_by_logits": names[cls_logit], |
|
|
"binding_model": pooled_or_unpooled, |
|
|
} |
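
    # Note: "class_by_threshold" re-buckets the regression output via affinity_to_class,
    # while "class_by_logits" is the argmax of the model's classification head; the two
    # can disagree on borderline predictions.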
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
predictor = PeptiVersePredictor( |
|
|
manifest_path="best_models.txt", |
|
|
classifier_weight_root="/vast/projects/pranam/lab/yz927/projects/Classifier_Weight" |
|
|
) |
|
|
print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ")) |
|
|
print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="...")) |
|
|
|
|
|
|
|
|
""" |
|
|
device = torch.device("cuda:0") |
|
|
|
|
|
wt = WTEmbedder(device) |
|
|
sm = SMILESEmbedder(device, |
|
|
vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt", |
|
|
splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt" |
|
|
) |
|
|
|
|
|
p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280) |
|
|
X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li) |
|
|
|
|
|
p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles) |
|
|
X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li) |
|
|
""" |
|
|
|