import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import optuna
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything

seed_everything(1986)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman correlation that returns 0.0 instead of NaN (e.g. for constant predictions)."""
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)

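# Illustrative behaviour (a sketch, not executed anywhere in this script): a constant
# prediction vector has an undefined rank correlation, which safe_spearmanr maps to 0.0, e.g.
#   safe_spearmanr(np.array([1.0, 2.0, 3.0]), np.array([5.0, 5.0, 5.0]))  # -> 0.0

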
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
    """Bucket continuous affinity labels into 3 classes: 0 = high (>= 9), 1 = mid, 2 = low (< 7)."""
    high = y >= 9.0
    low = y < 7.0
    mid = ~(high | low)
    cls = torch.zeros_like(y, dtype=torch.long)  # high stays at class 0 by default
    cls[mid] = 1
    cls[low] = 2
    return cls

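# Note: the 9.0 / 7.0 cut-offs assume labels on a -log10 affinity scale (pKd/pKi-like).
# Illustrative mapping under that assumption:
#   y = [9.3, 8.1, 6.5]  ->  classes [0, 1, 2]  (high, mid, low)

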
def load_split_paired(path: str):
    """Load a paired DatasetDict from disk and return its train/val splits."""
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


def collate_pair_pooled(batch):
    """Stack pooled (fixed-size) target/binder embeddings and labels into batch tensors."""
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y

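# Expected per-example schema (assumed from the collate functions; exact sizes depend on the
# upstream embedding model):
#   pooled rows:   {"target_embedding": [Ht floats], "binder_embedding": [Hb floats], "label": float}
#   unpooled rows: {"target_embedding": [[Ht floats] x Lt], "binder_embedding": [[Hb floats] x Lb],
#                   "target_attention_mask": [Lt ints], "binder_attention_mask": [Lb ints],
#                   "target_length": int, "binder_length": int, "label": float}

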
def collate_pair_unpooled(batch):
    """Pad per-residue target/binder embeddings to the batch max length and build boolean masks."""
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y

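# Shapes after collation: Pt (B, Lt_max, Ht), Pb (B, Lb_max, Hb), Mt (B, Lt_max), Mb (B, Lb_max),
# y (B,), with mask entries True for real residues and False for padding.

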
class CrossAttnPooled(nn.Module):
    """
    Pooled embedding vectors treated as single-token sequences for cross-attention.
    Two heads on a shared trunk: affinity regression and 3-way classification.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        # Shape (1, B, hidden): one "token" per example, sequence-first for batch_first=False attention.
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)

        for L in self.layers:
            # Target attends to binder, then binder attends to the updated target.
            # LayerNorm operates on the last dim, so no transposes are needed here.
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"](t + t_attn)
            t = L["n2t"](t + L["fft"](t))

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"](b + b_attn)
            b = L["n2b"](b + L["ffb"](b))

        t0 = t[0]
        b0 = b[0]
        z = torch.cat([t0, b0], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

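# Minimal forward-pass sketch (illustrative only, never called by this script; the embedding
# widths below are made up):
#   model = CrossAttnPooled(Ht=1280, Hb=640).to(DEVICE)
#   pred, logits = model(torch.randn(4, 1280, device=DEVICE), torch.randn(4, 640, device=DEVICE))
#   # pred: (4,) affinity estimates, logits: (4, 3) class scores

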
class CrossAttnUnpooled(nn.Module):
    """
    Per-residue token sequences with padding masks; alternating cross-attention,
    masked mean pooling, then shared regression + classification heads.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        # Mean over the sequence dim, counting only unmasked (real) tokens.
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        # key_padding_mask expects True where positions should be ignored (padding).
        kp_t = ~Mt
        kp_b = ~Mb

        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

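# Minimal forward-pass sketch (illustrative only, never called; widths and lengths are made up):
#   model = CrossAttnUnpooled(Ht=1280, Hb=640).to(DEVICE)
#   T = torch.randn(2, 200, 1280, device=DEVICE); Bx = torch.randn(2, 30, 640, device=DEVICE)
#   Mt = torch.ones(2, 200, dtype=torch.bool, device=DEVICE)
#   Mb = torch.ones(2, 30, dtype=torch.bool, device=DEVICE)
#   pred, logits = model(T, Mt, Bx, Mb)   # pred: (2,), logits: (2, 3)

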
@torch.no_grad()
def eval_spearman_pooled(model, loader):
    """Validation Spearman rho between predicted and true affinities (pooled inputs)."""
    model.eval()
    ys, ps = [], []
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        pred, _ = model(t, b)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))

@torch.no_grad()
def eval_spearman_unpooled(model, loader):
    """Validation Spearman rho for the unpooled (per-residue, masked) variant."""
    model.eval()
    ys, ps = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        pred, _ = model(T, Mt, B, Mb)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))

def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """One epoch of joint regression + classification training on pooled inputs."""
    model.train()
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(t, b)
        loss = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """One epoch of joint regression + classification training on padded per-residue inputs."""
    model.train()
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(T, Mt, B, Mb)
        loss = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()


def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    """Optuna objective: train a cross-attention model and return the best validation Spearman rho."""
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        collate = collate_pair_pooled
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        collate = collate_pair_unpooled
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best = -1e9
    bad = 0
    patience = 10

    for ep in range(1, 61):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        trial.report(rho, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Early stopping on validation Spearman.
        if rho > best + 1e-6:
            best = rho
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best)


def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
    """Run the Optuna search, then refit the best configuration and save the model and params."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)

    study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
    best = study.best_trial
    best_params = dict(best.params)

    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    dropout = float(best_params["dropout"])
    hidden = int(best_params["hidden_dim"])
    n_heads = int(best_params["n_heads"])
    n_layers = int(best_params["n_layers"])
    cls_w = float(best_params["cls_weight"])
    batch = int(best_params["batch_size"])

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_pooled
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_unpooled
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # Refit with a longer budget and more patience than during the search.
    best_rho = -1e9
    bad = 0
    patience = 20
    best_state = None

    for ep in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
    with open(out_dir / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--n_trials", type=int, default=50)
    args = ap.parse_args()

    run(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        mode=args.mode,
        n_trials=args.n_trials,
    )
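# Example invocation (script name and paths are placeholders):
#   python tune_crossattn.py --dataset_path data/pair_pooled --mode pooled \
#       --out_dir runs/crossattn_pooled --n_trials 50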