# peptiverse_infer.py
"""PeptiVerse inference: load the best per-property models listed in a manifest
CSV and score peptides given as amino-acid sequences ("wt") or SMILES strings.

Sequence embeddings come from ESM-2 (``WTEmbedder``); SMILES embeddings come
from PeptideCLM (``SMILESEmbedder``). Property heads may be XGBoost boosters,
joblib-pickled estimators or small torch networks; binding affinity uses a
cross-attention model over target and binder embeddings.
"""
from __future__ import annotations

import csv
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import joblib
import numpy as np
import torch
import torch.nn as nn
import xgboost as xgb
from transformers import AutoModelForMaskedLM, EsmModel, EsmTokenizer

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer


# -----------------------------
# Manifest
# -----------------------------

# Cell values treated as "no value" in the manifest.
_MISSING_CELLS = frozenset({"", "-", "—", "NA", "N/A"})


@dataclass(frozen=True)
class BestRow:
    """One manifest row: the best model (and threshold) per input mode for a property."""

    property_key: str            # normalized key, see normalize_property_key()
    best_wt: Optional[str]       # model label for AA-sequence input, or None
    best_smiles: Optional[str]   # model label for SMILES input, or None
    task_type: str               # "Classifier" or "Regression"
    thr_wt: Optional[float]      # decision threshold for the WT model, if any
    thr_smiles: Optional[float]  # decision threshold for the SMILES model, if any


def _clean(s: Optional[str]) -> str:
    """Return ``s`` without surrounding whitespace; ``None`` becomes ``""``."""
    return (s or "").strip()


def _none_if_dash(s: Optional[str]) -> Optional[str]:
    """Return the cleaned cell, or ``None`` if it is a missing-value marker."""
    s = _clean(s)
    return None if s in _MISSING_CELLS else s


def _float_or_none(s: Optional[str]) -> Optional[float]:
    """Parse a numeric cell; missing-value markers yield ``None``.

    Raises:
        ValueError: if the cell is present but is not a valid float.
    """
    s = _none_if_dash(s)
    return None if s is None else float(s)


def normalize_property_key(name: str) -> str:
    """Map a human-readable manifest property name to its folder key.

    Lower-cases, removes parenthesised qualifiers (and the whitespace around
    them), turns ``-`` and spaces into ``_`` and applies a few aliases.
    """
    n = name.strip().lower()
    n = re.sub(r"\s*\(.*?\)\s*", "", n)
    n = n.replace("-", "_").replace(" ", "_")
    # NOTE(review): the pampa/caco check runs after parentheses are removed, so
    # a name like "Permeability (PAMPA)" maps to "permeability_penetrance".
    # Confirm this matches the manifest's naming.
    if "permeability" in n and "pampa" not in n and "caco" not in n:
        return "permeability_penetrance"
    if n == "binding_affinity":
        return "binding_affinity"
    if n == "halflife":
        return "half_life"
    if n == "non_fouling":
        return "nf"
    return n


def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """Parse the best-model manifest CSV into ``{property_key: BestRow}``.

    Expected layout (trailing empty cells are ignored)::

        Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
        Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,

    Blank lines and rows without a ``Properties`` value are skipped. A later row
    for the same normalized property overrides an earlier one. A blank ``Type``
    cell defaults to ``"Classifier"``.

    Raises:
        ValueError: if a threshold cell is present but not numeric.
    """
    out: Dict[str, BestRow] = {}
    header: Optional[List[str]] = None
    # utf-8-sig: the manifest uses non-ASCII markers ("—") and may start with a
    # BOM, which would otherwise corrupt the first header name ("Properties").
    with Path(path).open("r", newline="", encoding="utf-8-sig") as f:
        for raw in csv.reader(f):
            # Drop trailing empty cells (rows end with a dangling comma); an
            # all-empty row becomes [] and is skipped.
            while raw and _clean(raw[-1]) == "":
                raw = raw[:-1]
            if not raw:
                continue
            if header is None:
                header = [h.strip() for h in raw]
                continue
            raw = raw + [""] * (len(header) - len(raw))  # pad short rows
            rec = dict(zip(header, raw))

            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)
            out[prop_key] = BestRow(
                property_key=prop_key,
                best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
                best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
                # Rows are padded with "", so a plain .get() default never fires.
                task_type=_clean(rec.get("Type", "")) or "Classifier",
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
    return out


# Manifest model label (upper-cased) -> model folder name prefix.
MODEL_ALIAS = {
    "SVM": "svm_gpu",
    "SVR": "svr",
    "ENET": "enet_gpu",
    "CNN": "cnn",
    "MLP": "mlp",
    "TRANSFORMER": "transformer",
    "XGB": "xgb",
    "XGB_REG": "xgb_reg",
    "POOLED": "pooled",
    "UNPOOLED": "unpooled",
}


def canon_model(label: Optional[str]) -> Optional[str]:
    """Return the canonical folder name for a manifest model label.

    Known labels (case-insensitive) go through ``MODEL_ALIAS``; unknown labels
    are lower-cased. ``None`` passes through unchanged.
    """
    if label is None:
        return None
    return MODEL_ALIAS.get(label.strip().upper(), label.strip().lower())


# -----------------------------
# Generic artifact loading
# -----------------------------

def find_best_artifact(model_dir: Path) -> Path:
    """Return the first ``best_model`` artifact in ``model_dir``.

    Patterns are tried in order ``best_model.json``, ``best_model.pt``,
    ``best_model*.joblib``. Within a pattern the first match in sorted order
    wins.

    Raises:
        FileNotFoundError: if no pattern matches.
    """
    for pat in ("best_model.json", "best_model.pt", "best_model*.joblib"):
        hits = sorted(model_dir.glob(pat))
        if hits:
            return hits[0]
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")


def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact found in ``model_dir``.

    Returns:
        ``(kind, obj, path)``. ``kind`` is one of:
        - ``"xgb"``: ``obj`` is an ``xgb.Booster``.
        - ``"joblib"``: ``obj`` is the unpickled object.
        - ``"torch_ckpt"``: ``obj`` is the raw checkpoint, still to be rebuilt
          into a module.

    Security: ``joblib.load`` and ``torch.load(weights_only=False)`` unpickle
    arbitrary objects. Only load artifacts from trusted locations.

    Raises:
        FileNotFoundError: if no artifact exists.
        ValueError: for an unknown artifact suffix.
    """
    art = find_best_artifact(model_dir)
    if art.suffix == ".json":
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art
    if art.suffix == ".joblib":
        return "joblib", joblib.load(art), art
    if art.suffix == ".pt":
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art
    raise ValueError(f"Unknown artifact type: {art}")


# -----------------------------
# NN architectures
# -----------------------------

def _masked_token_mean(X: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Average ``X`` over the token axis using only positions where ``M`` is set.

    X: (B, L, H), M: (B, L) -> (B, H). The denominator is clamped to at least 1,
    so a row with no valid tokens yields zeros.
    """
    Mf = M.unsqueeze(-1).float()
    denom = Mf.sum(dim=1).clamp(min=1.0)
    return (X * Mf).sum(dim=1) / denom


class MaskedMeanPool(nn.Module):
    """Parameter-free module wrapper around the masked token mean."""

    def forward(self, X, M):  # X:(B,L,H), M:(B,L)
        return _masked_token_mean(X, M)


class MLPHead(nn.Module):
    """Masked mean-pool over tokens, then a one-hidden-layer MLP producing a scalar."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        return self.net(self.pool(X, M)).squeeze(-1)


class CNNHead(nn.Module):
    """1D convolutions over tokens (padding=k//2), masked mean-pool, linear -> scalar."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k // 2), nn.GELU(), nn.Dropout(dropout)]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d wants channels first: (B, L, H) -> (B, H, L), then back to (B, L, C).
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        return self.head(_masked_token_mean(Y, M)).squeeze(-1)


class TransformerHead(nn.Module):
    """Linear projection, Transformer encoder, masked mean-pool, linear -> scalar."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # M must be a bool mask with True = real token; the encoder's padding
        # mask uses the opposite convention (True = ignore).
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        return self.head(_masked_token_mean(Z, M)).squeeze(-1)


def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
    """Read the input feature size from the first layer's weight in a state dict."""
    if model_name == "mlp":
        return int(sd["net.0.weight"].shape[1])
    if model_name == "cnn":
        return int(sd["conv.0.weight"].shape[1])
    if model_name == "transformer":
        return int(sd["proj.weight"].shape[1])
    raise ValueError(model_name)


def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """Rebuild an MLP/CNN/Transformer head from a checkpoint and load its weights.

    ``ckpt`` must contain ``best_params`` (hyper-parameters) and ``state_dict``.
    It may contain ``in_dim``; otherwise the input size is read from the state
    dict. The model is moved to ``device`` and put in eval mode.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    in_dim = ckpt.get("in_dim")
    if in_dim is None:
        # Only inspect the state dict when the checkpoint does not record in_dim.
        in_dim = _infer_in_dim_from_sd(sd, model_name)
    in_dim = int(in_dim)
    dropout = float(params.get("dropout", 0.1))

    if model_name == "mlp":
        model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNHead(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        model = TransformerHead(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")

    model.load_state_dict(sd)
    model.to(device)
    model.eval()
    return model


# -----------------------------
# Binding affinity models
# -----------------------------

# Display names for affinity_to_class() / classifier-head indices.
_AFFINITY_CLASS_NAMES = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"}


def affinity_to_class(y: float) -> int:
    """Bucket an affinity value: 0 = High (>= 9), 1 = Moderate (7 to 9), 2 = Low (< 7)."""
    if y >= 9.0:
        return 0
    if y < 7.0:
        return 2
    return 1


def _cross_attn_layer(hidden: int, n_heads: int, dropout: float, batch_first: bool) -> nn.ModuleDict:
    """Build one target<->binder cross-attention block with feed-forward sublayers.

    The submodule names are part of the checkpoint state_dict keys; do not rename them.
    """
    def ffn() -> nn.Sequential:
        return nn.Sequential(
            nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)
        )

    return nn.ModuleDict({
        "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=batch_first),
        "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=batch_first),
        "n1t": nn.LayerNorm(hidden),
        "n2t": nn.LayerNorm(hidden),
        "n1b": nn.LayerNorm(hidden),
        "n2b": nn.LayerNorm(hidden),
        "fft": ffn(),
        "ffb": ffn(),
    })


class _CrossAttnBase(nn.Module):
    """Shared layers of the pooled/unpooled binding models.

    Holds the projections, the cross-attention stack and the regression and
    3-class heads. Attribute names match the checkpoint state_dict keys.
    """

    def __init__(self, Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        self.layers = nn.ModuleList(
            [_cross_attn_layer(hidden, n_heads, dropout, batch_first) for _ in range(n_layers)]
        )
        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def _predict(self, t_pool: torch.Tensor, b_pool: torch.Tensor):
        """Concatenate pooled target/binder features -> (affinity (B,), class logits (B, 3))."""
        h = self.shared(torch.cat([t_pool, b_pool], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)


class CrossAttnPooled(_CrossAttnBase):
    """Binding model over pooled (one vector per sequence) target/binder embeddings."""

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__(Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first=False)

    def forward(self, t_vec, b_vec):
        # Sequence-first layout with a length-1 "sequence" on each side.
        t = self.t_proj(t_vec).unsqueeze(0)  # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1,B,H)
        for blk in self.layers:
            t_attn, _ = blk["attn_tb"](t, b, b)
            t = blk["n1t"]((t + t_attn).transpose(0, 1)).transpose(0, 1)
            t = blk["n2t"]((t + blk["fft"](t)).transpose(0, 1)).transpose(0, 1)
            b_attn, _ = blk["attn_bt"](b, t, t)
            b = blk["n1b"]((b + b_attn).transpose(0, 1)).transpose(0, 1)
            b = blk["n2b"]((b + blk["ffb"](b)).transpose(0, 1)).transpose(0, 1)
        return self._predict(t[0], b[0])


class CrossAttnUnpooled(_CrossAttnBase):
    """Binding model over per-token target/binder embeddings with padding masks."""

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__(Ht, Hb, hidden, n_heads, n_layers, dropout, batch_first=True)

    def _masked_mean(self, X, M):
        return _masked_token_mean(X, M)

    def forward(self, T, Mt, B, Mb):
        # Mt / Mb: bool masks, True = real token (inverted for key_padding_mask).
        T = self.t_proj(T)
        Bx = self.b_proj(B)
        kp_t = ~Mt
        kp_b = ~Mb
        for blk in self.layers:
            T_attn, _ = blk["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = blk["n1t"](T + T_attn)
            T = blk["n2t"](T + blk["fft"](T))
            B_attn, _ = blk["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = blk["n1b"](Bx + B_attn)
            Bx = blk["n2b"](Bx + blk["ffb"](Bx))
        return self._predict(self._masked_mean(T, Mt), self._masked_mean(Bx, Mb))


def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Rebuild a pooled/unpooled binding model from ``best_model_pt``.

    The target/binder input sizes are read from the projection weights.
    Security: ``weights_only=False`` unpickles arbitrary objects; trusted files only.

    Raises:
        ValueError: if ``pooled_or_unpooled`` is neither "pooled" nor "unpooled".
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]

    common = dict(
        Ht=int(sd["t_proj.0.weight"].shape[1]),
        Hb=int(sd["b_proj.0.weight"].shape[1]),
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    )
    if pooled_or_unpooled == "pooled":
        model = CrossAttnPooled(**common)
    elif pooled_or_unpooled == "unpooled":
        model = CrossAttnUnpooled(**common)
    else:
        raise ValueError(pooled_or_unpooled)

    model.load_state_dict(sd)
    model.to(device).eval()
    return model


# -----------------------------
# Embedding generation
# -----------------------------

def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
    """``torch.isin`` with a broadcasting fallback for PyTorch versions that lack it."""
    if hasattr(torch, "isin"):
        return torch.isin(ids, test_ids)
    # (B,L,1) == (1,1,K) -> (B,L,K), then any over K.
    return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)


class _TokenEmbedder:
    """Shared tokenize / encode / pool logic for the sequence and SMILES embedders.

    Subclasses set ``device``, ``max_len``, ``use_cache``, ``tokenizer`` and
    ``model``, then call ``_init_special_ids_and_caches()``, and implement
    ``_encode``. When ``use_cache`` is True, results are memoised per stripped
    input string in unbounded dicts. Cached tensors are returned by reference,
    so callers must not modify them in place.
    """

    device: torch.device
    max_len: int
    use_cache: bool

    def _init_special_ids_and_caches(self) -> None:
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (
            torch.tensor(self.special_ids, device=self.device, dtype=torch.long)
            if len(self.special_ids)
            else None
        )
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Return the sorted, de-duplicated ids of the special tokens the tokenizer defines."""
        names = (
            "pad_token_id",
            "cls_token_id",
            "sep_token_id",
            "bos_token_id",
            "eos_token_id",
            "mask_token_id",
        )
        found = (getattr(tokenizer, n, None) for n in names)
        return sorted({int(x) for x in found if x is not None})

    def _tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize with padding/truncation to ``max_len``; move tensors to ``device``."""
        tok = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        tok = {k: v.to(self.device) for k, v in tok.items()}
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _encode(self, tok: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the encoder and return ``last_hidden_state`` (1, L, H)."""
        raise NotImplementedError

    @torch.no_grad()
    def _hidden_and_valid(self, s: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return hidden states (1,L,H) and a (1,L) bool mask of non-pad, non-special tokens."""
        tok = self._tokenize([s])
        ids = tok["input_ids"]                # (1,L)
        valid = tok["attention_mask"].bool()  # (1,L)
        h = self._encode(tok)                 # (1,L,H)
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return h, valid

    @torch.no_grad()
    def _pooled(self, text: str) -> torch.Tensor:
        """Mean of hidden states over valid tokens -> (1, H); zeros if none are valid."""
        s = text.strip()
        if self.use_cache and s in self._cache_pooled:
            return self._cache_pooled[s]
        h, valid = self._hidden_and_valid(s)
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)  # (1,H)
        if self.use_cache:
            self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def _unpooled(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Hidden states of valid tokens only -> X (1, Li, H) and an all-True mask M (1, Li)."""
        s = text.strip()
        if self.use_cache and s in self._cache_unpooled:
            return self._cache_unpooled[s]
        h, valid = self._hidden_and_valid(s)
        X = h[:, valid[0], :]  # (1,Li,H)
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[s] = (X, M)
        return X, M


class SMILESEmbedder(_TokenEmbedder):
    """
    PeptideCLM RoFormer embeddings for SMILES.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    """

    def __init__(
        self,
        device: torch.device,
        vocab_path: str,
        splits_path: str,
        clm_name: str = "aaronfeller/PeptideCLM-23M-all",
        max_len: int = 512,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Keep only the RoFormer encoder of the masked-LM checkpoint.
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
        self._init_special_ids_and_caches()

    def _encode(self, tok: Dict[str, torch.Tensor]) -> torch.Tensor:
        out = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"])
        return out.last_hidden_state

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        """Return the (1, H) mean embedding of ``smiles`` on ``device``."""
        return self._pooled(smiles)

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          X: (1, Li, H) float32 on device
          M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        return self._unpooled(smiles)


class WTEmbedder(_TokenEmbedder):
    """
    ESM2 embeddings for AA sequences.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    """

    def __init__(
        self,
        device: torch.device,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        max_len: int = 1022,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self._init_special_ids_and_caches()

    def _encode(self, tok: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.model(**tok).last_hidden_state

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        """Return the (1, H) mean embedding of ``seq`` on ``device``."""
        return self._pooled(seq)

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          X: (1, Li, H) float32 on device
          M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        return self._unpooled(seq)


# -----------------------------
# Predictor
# -----------------------------

def _sigmoid(x: float) -> float:
    """Logistic function; never passes a positive argument to exp (no overflow warnings)."""
    if x >= 0:
        return float(1.0 / (1.0 + np.exp(-x)))
    z = float(np.exp(x))
    return z / (1.0 + z)


class PeptiVersePredictor:
    """
    - loads best models from training_classifiers/
    - computes embeddings as needed (pooled/unpooled)
    - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
    """

    def __init__(
        self,
        manifest_path: str | Path,
        classifier_weight_root: str | Path,
        esm_name="facebook/esm2_t33_650M_UR50D",
        clm_name="aaronfeller/PeptideCLM-23M-all",
        smiles_vocab="tokenizer/new_vocab.txt",
        smiles_splits="tokenizer/new_splits.txt",
        device: Optional[str] = None,
    ):
        self.root = Path(classifier_weight_root)
        self.training_root = self.root / "training_classifiers"
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        self.manifest = read_best_manifest_csv(manifest_path)

        self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
        self.smiles_embedder = SMILESEmbedder(
            self.device,
            clm_name=clm_name,
            vocab_path=str(self.root / smiles_vocab),
            splits_path=str(self.root / smiles_splits),
        )

        # Keyed by (property_key, mode) with mode in {"wt", "smiles"}.
        self.models: Dict[Tuple[str, str], Any] = {}
        self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
        self._load_all_best_models()

    def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
        """Locate the folder of the trained model for (property, model, mode).

        Tried in order:
          - training_classifiers/<prop_key>/<model_name>_<mode>/
          - training_classifiers/<prop_key>/<model_name>/

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        base = self.training_root / prop_key
        candidates = [base / f"{model_name}_{mode}", base / model_name]
        for d in candidates:
            if d.exists():
                return d
        raise FileNotFoundError(
            f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}"
        )

    def _load_all_best_models(self) -> None:
        """Load the model for every (property, mode) that the manifest names."""
        for prop_key, row in self.manifest.items():
            for mode, label, thr in (
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ):
                m = canon_model(label)
                if m is None:
                    continue  # no model for this input mode
                if prop_key == "binding_affinity":
                    self._load_binding_entry(mode, m)
                else:
                    self._load_property_entry(prop_key, mode, m, row.task_type, thr)

    def _load_binding_entry(self, mode: str, pooled_or_unpooled: str) -> None:
        """Load binding_affinity/wt_<mode>_<pooled|unpooled>/best_model.pt for ``mode``."""
        folder = f"wt_{mode}_{pooled_or_unpooled}"  # e.g. wt_wt_pooled, wt_smiles_unpooled
        model_dir = self.training_root / "binding_affinity" / folder
        art = find_best_artifact(model_dir)
        if art.suffix != ".pt":
            raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
        key = ("binding_affinity", mode)
        self.models[key] = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
        self.meta[key] = {
            "task_type": "Regression",
            "threshold": None,
            "artifact": str(art),
            "model_name": pooled_or_unpooled,
        }

    def _load_property_entry(
        self, prop_key: str, mode: str, model_name: str, task_type: str, thr: Optional[float]
    ) -> None:
        """Load one property model and record its metadata."""
        model_dir = self._resolve_dir(prop_key, model_name, mode)
        kind, obj, art = load_artifact(model_dir, self.device)
        if kind == "torch_ckpt":
            # Checkpoints hold weights + hyper-parameters; rebuild the nn.Module.
            obj = build_torch_model_from_ckpt(model_name, obj, self.device)
        self.models[(prop_key, mode)] = obj
        self.meta[(prop_key, mode)] = {
            "task_type": task_type,
            "threshold": thr,
            "artifact": str(art),
            "model_name": model_name,
            "kind": kind,
        }

    def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
        """
        Returns either:
          - pooled np.float32 array shape (1,H) for xgb/joblib
          - unpooled torch tensors (X,M) for NN
        """
        kind = self.meta[(prop_key, mode)].get("kind", None)

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        embedder = self.wt_embedder if mode == "wt" else self.smiles_embedder

        # Torch NN heads consume per-token embeddings.
        if kind == "torch_ckpt":
            return embedder.unpooled(input_str)

        # xgb/joblib estimators consume a pooled vector.
        feats = embedder.pooled(input_str).detach().cpu().numpy().astype(np.float32)  # (1,H)
        f32 = np.finfo(np.float32)
        return np.nan_to_num(feats, nan=0.0, posinf=f32.max, neginf=f32.min)

    @staticmethod
    def _result(prop_key: str, mode: str, score: float, thr: Optional[float] = None) -> Dict[str, Any]:
        """Build the prediction dict; add label/threshold only when ``thr`` is given."""
        out: Dict[str, Any] = {"property": prop_key, "mode": mode, "score": score}
        if thr is not None:
            out["label"] = int(score >= float(thr))
            out["threshold"] = float(thr)
        return out

    def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
        """
        mode: "wt" for AA sequence input, "smiles" for SMILES input
        Returns dict with score + label if classifier threshold exists.

        Classifier scores are probabilities: sigmoid of NN logits, ``predict_proba``
        or sigmoid of ``decision_function`` for joblib estimators, and raw booster
        output for xgb.

        Raises:
            KeyError: if no model is loaded for (prop_key, mode).
            RuntimeError: for binding_affinity (use predict_binding_affinity) or an
                unknown model kind.
        """
        key = (prop_key, mode)
        if key not in self.models:
            raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")

        meta = self.meta[key]
        model = self.models[key]
        task_type = meta["task_type"].lower()
        thr = meta.get("threshold", None)
        kind = meta.get("kind", None)
        is_classifier = task_type == "classifier"

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # NN path (logits / regression)
        if kind == "torch_ckpt":
            X, M = self._get_features_for_model(prop_key, mode, input_str)
            with torch.no_grad():
                y = model(X, M).squeeze().float().cpu().item()
            if is_classifier:
                return self._result(prop_key, mode, _sigmoid(y), thr)
            return self._result(prop_key, mode, float(y))

        # xgb path
        if kind == "xgb":
            feats = self._get_features_for_model(prop_key, mode, input_str)  # (1,H)
            pred = float(model.predict(xgb.DMatrix(feats))[0])
            return self._result(prop_key, mode, pred, thr if is_classifier else None)

        # joblib path (svm/enet/svr)
        if kind == "joblib":
            feats = self._get_features_for_model(prop_key, mode, input_str)  # (1,H)
            if not is_classifier:
                return self._result(prop_key, mode, float(model.predict(feats)[0]))
            if hasattr(model, "predict_proba"):
                pred = float(model.predict_proba(feats)[:, 1][0])
            elif hasattr(model, "decision_function"):
                pred = _sigmoid(float(model.decision_function(feats)[0]))
            else:
                pred = float(model.predict(feats)[0])
            return self._result(prop_key, mode, pred, thr)

        raise RuntimeError(f"Unknown model kind={kind}")

    def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
        """
        mode:
          "wt"     (binder is AA sequence) -> wt_wt_(pooled|unpooled)
          "smiles" (binder is SMILES)      -> wt_smiles_(pooled|unpooled)

        The target is always embedded as an AA sequence with ESM. The result
        reports the regressed affinity, its bucket from affinity_to_class(),
        and the argmax of the model's 3-class head.

        Raises:
            KeyError: if no binding model is loaded for ``mode``.
        """
        prop_key = "binding_affinity"
        key = (prop_key, mode)
        if key not in self.models:
            raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")

        model = self.models[key]
        pooled_or_unpooled = self.meta[key]["model_name"]  # pooled/unpooled
        binder_embedder = self.wt_embedder if mode == "wt" else self.smiles_embedder

        with torch.no_grad():
            if pooled_or_unpooled == "pooled":
                t_vec = self.wt_embedder.pooled(target_seq)  # (1,Ht)
                b_vec = binder_embedder.pooled(binder_str)   # (1,Hb)
                reg, logits = model(t_vec, b_vec)
            else:
                T, Mt = self.wt_embedder.unpooled(target_seq)
                B, Mb = binder_embedder.unpooled(binder_str)
                reg, logits = model(T, Mt, B, Mb)

        affinity = float(reg.squeeze().cpu().item())
        cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
        cls_thr = affinity_to_class(affinity)

        return {
            "property": "binding_affinity",
            "mode": mode,
            "affinity": affinity,
            "class_by_threshold": _AFFINITY_CLASS_NAMES[cls_thr],
            "class_by_logits": _AFFINITY_CLASS_NAMES[cls_logit],
            "binding_model": pooled_or_unpooled,
        }


# -----------------------------
# Minimal usage
# -----------------------------
if __name__ == "__main__":
    # Example:
    predictor = PeptiVersePredictor(
        manifest_path="best_models.txt",
        classifier_weight_root="/vast/projects/pranam/lab/yz927/projects/Classifier_Weight",
    )
    print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
    print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))

    # Stand-alone embedding check (kept for reference, not executed):
    # device = torch.device("cuda:0")
    # wt = WTEmbedder(device)
    # sm = SMILESEmbedder(
    #     device,
    #     vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt",
    #     splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt",
    # )
    # p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ")        # (1,1280)
    # X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ")   # (1,Li,1280), (1,Li)
    # p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O")          # (1,H_smiles)
    # X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O")    # (1,Li,H_smiles), (1,Li)