# PeptiVerse/training_classifiers/binding_training.py
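# Module overview (summary inferred from the code below):
"""
Cross-attention heads for target-binder binding-affinity prediction.

Loads a paired HuggingFace DatasetDict of precomputed target/binder embeddings
(pooled vectors or per-token sequences), runs an Optuna hyperparameter search
that maximizes validation Spearman rho of the regression head, then refits the
best configuration with early stopping and saves the weights and parameters.
"""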
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import optuna
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything
seed_everything(1986)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
rho = spearmanr(y_true, y_pred).correlation
if rho is None or np.isnan(rho):
return 0.0
return float(rho)
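# Illustrative behaviour (hypothetical values): constant predictions make
# scipy's spearmanr return NaN, which this wrapper maps to 0.0, e.g.
#   safe_spearmanr(np.array([1.0, 2.0, 3.0]), np.array([5.0, 5.0, 5.0])) -> 0.0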
# -----------------------------
# Affinity class thresholds (final spec)
# High: y >= 9 ; Moderate: 7 <= y < 9 ; Low: y < 7
# 0=High, 1=Moderate, 2=Low
# -----------------------------
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
high = y >= 9.0
low = y < 7.0
mid = ~(high | low)
cls = torch.zeros_like(y, dtype=torch.long)
cls[mid] = 1
cls[low] = 2
return cls
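# Example mapping (hypothetical affinity values):
#   affinity_to_class_tensor(torch.tensor([9.5, 8.0, 6.2])) -> tensor([0, 1, 2])
#   (9.5 -> High, 8.0 -> Moderate, 6.2 -> Low)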
# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
dd = load_from_disk(path)
if not isinstance(dd, DatasetDict):
raise ValueError(f"Expected DatasetDict at {path}")
if "train" not in dd or "val" not in dd:
raise ValueError(f"DatasetDict missing train/val at {path}")
return dd["train"], dd["val"]
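# Expected per-row fields, as consumed by the collate functions below
# (inferred from their field accesses, not from a separate schema file):
#   pooled mode:   "target_embedding" (Ht,), "binder_embedding" (Hb,), "label" float
#   unpooled mode: "target_embedding" (Lt, Ht), "binder_embedding" (Lb, Hb),
#                  "target_length", "binder_length",
#                  "target_attention_mask" (Lt,), "binder_attention_mask" (Lb,), "label" float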
# -----------------------------
# Collate: pooled paired
# -----------------------------
def collate_pair_pooled(batch):
Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
return Pt, Pb, y
# -----------------------------
# Collate: unpooled paired
# -----------------------------
def collate_pair_unpooled(batch):
B = len(batch)
Ht = len(batch[0]["target_embedding"][0])
Hb = len(batch[0]["binder_embedding"][0])
Lt_max = max(int(x["target_length"]) for x in batch)
Lb_max = max(int(x["binder_length"]) for x in batch)
Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
for i, x in enumerate(batch):
t = torch.tensor(x["target_embedding"], dtype=torch.float32)
b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
lt, lb = t.shape[0], b.shape[0]
Pt[i, :lt] = t
Pb[i, :lb] = b
Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
return Pt, Mt, Pb, Mb, y
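# Illustrative output shapes (hypothetical batch): two rows with target lengths
# 5 and 3, binder lengths 4 and 6, and Ht = Hb = 8 collate into
#   Pt (2, 5, 8), Mt (2, 5), Pb (2, 6, 8), Mb (2, 6), y (2,)
# with padded positions zero-filled and excluded via Mt/Mb.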
# -----------------------------
# Cross-attention models
# -----------------------------
class CrossAttnPooled(nn.Module):
"""
pooled vectors -> treat as single-token sequences for cross attention
"""
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
super().__init__()
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
self.layers = nn.ModuleList([])
for _ in range(n_layers):
self.layers.append(nn.ModuleDict({
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
"n1t": nn.LayerNorm(hidden),
"n2t": nn.LayerNorm(hidden),
"n1b": nn.LayerNorm(hidden),
"n2b": nn.LayerNorm(hidden),
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
}))
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
self.reg = nn.Linear(hidden, 1)
self.cls = nn.Linear(hidden, 3)
def forward(self, t_vec, b_vec):
# (B,Ht),(B,Hb)
t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
for L in self.layers:
            # target token attends to binder token, then position-wise FFN
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"](t + t_attn)
            t = L["n2t"](t + L["fft"](t))
            # binder token attends to the updated target token, then FFN
            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"](b + b_attn)
            b = L["n2b"](b + L["ffb"](b))
t0 = t[0]
b0 = b[0]
z = torch.cat([t0, b0], dim=-1)
h = self.shared(z)
return self.reg(h).squeeze(-1), self.cls(h)
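# Minimal shape check for CrossAttnPooled (illustrative only; dimensions are
# arbitrary and this is not executed during training):
#   m = CrossAttnPooled(Ht=1280, Hb=960)
#   reg, logits = m(torch.randn(4, 1280), torch.randn(4, 960))
#   reg.shape -> torch.Size([4]) ; logits.shape -> torch.Size([4, 3])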
class CrossAttnUnpooled(nn.Module):
"""
token sequences with masks; alternating cross attention.
"""
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
super().__init__()
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
self.layers = nn.ModuleList([])
for _ in range(n_layers):
self.layers.append(nn.ModuleDict({
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
"n1t": nn.LayerNorm(hidden),
"n2t": nn.LayerNorm(hidden),
"n1b": nn.LayerNorm(hidden),
"n2b": nn.LayerNorm(hidden),
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
}))
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
self.reg = nn.Linear(hidden, 1)
self.cls = nn.Linear(hidden, 3)
def masked_mean(self, X, M):
Mf = M.unsqueeze(-1).float()
denom = Mf.sum(dim=1).clamp(min=1.0)
return (X * Mf).sum(dim=1) / denom
def forward(self, T, Mt, B, Mb):
# T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
T = self.t_proj(T)
Bx = self.b_proj(B)
kp_t = ~Mt # key_padding_mask True = pad
kp_b = ~Mb
for L in self.layers:
# T attends to B
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
T = L["n1t"](T + T_attn)
T = L["n2t"](T + L["fft"](T))
# B attends to T
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
Bx = L["n1b"](Bx + B_attn)
Bx = L["n2b"](Bx + L["ffb"](Bx))
t_pool = self.masked_mean(T, Mt)
b_pool = self.masked_mean(Bx, Mb)
z = torch.cat([t_pool, b_pool], dim=-1)
h = self.shared(z)
return self.reg(h).squeeze(-1), self.cls(h)
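# Minimal shape check for CrossAttnUnpooled (illustrative only; dimensions are
# arbitrary):
#   m = CrossAttnUnpooled(Ht=1280, Hb=960)
#   T, B_ = torch.randn(2, 7, 1280), torch.randn(2, 5, 960)
#   Mt, Mb = torch.ones(2, 7, dtype=torch.bool), torch.ones(2, 5, dtype=torch.bool)
#   reg, logits = m(T, Mt, B_, Mb)   # -> shapes (2,) and (2, 3)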
# -----------------------------
# Train/eval
# -----------------------------
@torch.no_grad()
def eval_spearman_pooled(model, loader):
model.eval()
ys, ps = [], []
for t, b, y in loader:
t = t.to(DEVICE, non_blocking=True)
b = b.to(DEVICE, non_blocking=True)
pred, _ = model(t, b)
ys.append(y.numpy())
ps.append(pred.detach().cpu().numpy())
return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
@torch.no_grad()
def eval_spearman_unpooled(model, loader):
model.eval()
ys, ps = [], []
for T, Mt, B, Mb, y in loader:
T = T.to(DEVICE, non_blocking=True)
Mt = Mt.to(DEVICE, non_blocking=True)
B = B.to(DEVICE, non_blocking=True)
Mb = Mb.to(DEVICE, non_blocking=True)
pred, _ = model(T, Mt, B, Mb)
ys.append(y.numpy())
ps.append(pred.detach().cpu().numpy())
return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
model.train()
for t, b, y in loader:
t = t.to(DEVICE, non_blocking=True)
b = b.to(DEVICE, non_blocking=True)
y = y.to(DEVICE, non_blocking=True)
y_cls = affinity_to_class_tensor(y)
opt.zero_grad(set_to_none=True)
pred, logits = model(t, b)
L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
L.backward()
if clip is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
opt.step()
def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
model.train()
for T, Mt, B, Mb, y in loader:
T = T.to(DEVICE, non_blocking=True)
Mt = Mt.to(DEVICE, non_blocking=True)
B = B.to(DEVICE, non_blocking=True)
Mb = Mb.to(DEVICE, non_blocking=True)
y = y.to(DEVICE, non_blocking=True)
y_cls = affinity_to_class_tensor(y)
opt.zero_grad(set_to_none=True)
pred, logits = model(T, Mt, B, Mb)
L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
L.backward()
if clip is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
opt.step()
# -----------------------------
# Optuna objective
# -----------------------------
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
dropout = trial.suggest_float("dropout", 0.0, 0.4)
hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
n_heads = trial.suggest_categorical("n_heads", [4, 8])
n_layers = trial.suggest_int("n_layers", 1, 4)
cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
# infer dims from first row
if mode == "pooled":
Ht = len(train_ds[0]["target_embedding"])
Hb = len(train_ds[0]["binder_embedding"])
collate = collate_pair_pooled
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
eval_fn = eval_spearman_pooled
train_fn = train_one_epoch_pooled
else:
Ht = len(train_ds[0]["target_embedding"][0])
Hb = len(train_ds[0]["binder_embedding"][0])
collate = collate_pair_unpooled
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
eval_fn = eval_spearman_unpooled
train_fn = train_one_epoch_unpooled
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
loss_reg = nn.MSELoss()
loss_cls = nn.CrossEntropyLoss()
best = -1e9
bad = 0
patience = 10
for ep in range(1, 61):
train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
rho = eval_fn(model, val_loader)
trial.report(rho, ep)
if trial.should_prune():
raise optuna.TrialPruned()
if rho > best + 1e-6:
best = rho
bad = 0
else:
bad += 1
if bad >= patience:
break
return float(best)
# -----------------------------
# Run: optuna + refit best
# -----------------------------
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
train_ds, val_ds = load_split_paired(dataset_path)
print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
best = study.best_trial
best_params = dict(best.params)
# refit longer
lr = float(best_params["lr"])
wd = float(best_params["weight_decay"])
dropout = float(best_params["dropout"])
hidden = int(best_params["hidden_dim"])
n_heads = int(best_params["n_heads"])
n_layers = int(best_params["n_layers"])
cls_w = float(best_params["cls_weight"])
batch = int(best_params["batch_size"])
loss_reg = nn.MSELoss()
loss_cls = nn.CrossEntropyLoss()
if mode == "pooled":
Ht = len(train_ds[0]["target_embedding"])
Hb = len(train_ds[0]["binder_embedding"])
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
collate = collate_pair_pooled
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
eval_fn = eval_spearman_pooled
train_fn = train_one_epoch_pooled
else:
Ht = len(train_ds[0]["target_embedding"][0])
Hb = len(train_ds[0]["binder_embedding"][0])
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
collate = collate_pair_unpooled
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
eval_fn = eval_spearman_unpooled
train_fn = train_one_epoch_unpooled
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
best_rho = -1e9
bad = 0
patience = 20
best_state = None
for ep in range(1, 201):
train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
rho = eval_fn(model, val_loader)
if rho > best_rho + 1e-6:
best_rho = rho
bad = 0
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
else:
bad += 1
if bad >= patience:
break
if best_state is not None:
model.load_state_dict(best_state)
# save
torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
with open(out_dir / "best_params.json", "w") as f:
json.dump(best_params, f, indent=2)
print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
if __name__ == "__main__":
import argparse
ap = argparse.ArgumentParser()
ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
ap.add_argument("--out_dir", type=str, required=True)
ap.add_argument("--n_trials", type=int, default=50)
args = ap.parse_args()
run(
dataset_path=args.dataset_path,
out_dir=args.out_dir,
mode=args.mode,
n_trials=args.n_trials,
)
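# Example invocation (paths are placeholders):
#   python binding_training.py \
#       --dataset_path data/pair_pooled \
#       --mode pooled \
#       --out_dir runs/binding_pooled \
#       --n_trials 50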