# PeptiVerse / training_data_cleaned / binding_affinity_split.py
import os
import math
from pathlib import Path
import sys
from contextlib import contextmanager
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
from lightning.pytorch import seed_everything
seed_everything(1986)
CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")
OUT_ROOT = Path("./Classifier_Weight/training_data_cleaned/binding_affinity")
# WT embedding model
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022
WT_BATCH = 32
# SMILES embedding model + tokenizer
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128
# Split config
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30
COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"
# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
QUIET = True
USE_TQDM = False
LOG_FILE = None
def log(msg: str):
if LOG_FILE is not None:
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
with open(LOG_FILE, "a") as f:
f.write(msg.rstrip() + "\n")
if not QUIET:
print(msg)
def pbar(it, **kwargs):
return tqdm(it, **kwargs) if USE_TQDM else it
@contextmanager
def section(title: str):
log(f"\n=== {title} ===")
yield
log(f"=== done: {title} ===")
# -------------------------
# Helpers
# -------------------------
def has_uaa(seq: str) -> bool:
return "X" in str(seq).upper()
def affinity_to_class(a: float) -> str:
# High: >= 9 ; Moderate: [7, 9) ; Low: < 7
if a >= 9.0:
return "High"
elif a >= 7.0:
return "Moderate"
else:
return "Low"
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
df = df.copy()
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
try:
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
strat_col = "aff_bin"
except Exception:
df["aff_bin"] = df["affinity_class"]
strat_col = "aff_bin"
rng = np.random.RandomState(RANDOM_SEED)
df["split"] = None
for _, g in df.groupby(strat_col, observed=True):
idx = g.index.to_numpy()
rng.shuffle(idx)
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
df.loc[idx[:n_train], "split"] = "train"
df.loc[idx[n_train:], "split"] = "val"
df["split"] = df["split"].fillna("train")
return df
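# Sanity-check sketch for the stratified split above: within each affinity
# quantile bin, roughly TRAIN_FRAC of the rows should land in "train". Assuming
# `df2 = make_distribution_matched_split(df)` has been run:
#
#   df2.groupby("aff_bin", observed=True)["split"].value_counts(normalize=True)
#
# should show ~0.80 train / ~0.20 val inside every bin, so train and val share
# the same affinity distribution rather than differing by chance.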
def _summ(x):
x = np.asarray(x, dtype=float)
x = x[~np.isnan(x)]
if len(x) == 0:
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
return {
"n": int(len(x)),
"mean": float(np.mean(x)),
"std": float(np.std(x)),
"p50": float(np.quantile(x, 0.50)),
"p95": float(np.quantile(x, 0.95)),
}
def _len_stats(seqs):
lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
if len(lens) == 0:
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
return {
"n": int(len(lens)),
"mean": float(lens.mean()),
"std": float(lens.std()),
"p50": float(np.quantile(lens, 0.50)),
"p95": float(np.quantile(lens, 0.95)),
}
def verify_split_before_embedding(
df2: pd.DataFrame,
affinity_col: str,
split_col: str,
seq_col: str,
iptm_col: str,
aff_class_col: str = "affinity_class",
aff_bins: int = 30,
save_report_prefix: str | None = None,
verbose: bool = False,
):
df2 = df2.copy()
df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
assert split_col in df2.columns, f"Missing split col: {split_col}"
assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
try:
df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
except Exception:
df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
tr = df2[df2[split_col] == "train"].reset_index(drop=True)
va = df2[df2[split_col] == "val"].reset_index(drop=True)
tr_aff = _summ(tr[affinity_col].to_numpy())
va_aff = _summ(va[affinity_col].to_numpy())
tr_len = _len_stats(tr[seq_col].tolist())
va_len = _len_stats(va[seq_col].tolist())
    # Bin drift: compare per-bin proportions between train and val.
    # unstack() is used instead of groupby(...).apply(), whose index handling
    # differs across pandas versions and can silently misalign the two splits.
    bin_ct = (
        df2.groupby([split_col, "_aff_bin_dbg"], observed=False)
        .size()
        .unstack(fill_value=0)
    )
    bin_prop = bin_ct.div(bin_ct.sum(axis=1), axis=0)
    tr_bins = bin_prop.loc["train"]
    va_bins = bin_prop.loc["val"]
    max_bin_diff = float((tr_bins - va_bins).abs().max())
msg = (
f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
f"max_bin_diff={max_bin_diff:.4f}"
)
log(msg)
if verbose and (not QUIET):
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
print("\n[verbose] affinity_class counts:\n", class_ct)
print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
if save_report_prefix is not None:
out = Path(save_report_prefix)
out.parent.mkdir(parents=True, exist_ok=True)
stats_df = pd.DataFrame([
{"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
{"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
])
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
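# Interpretation note: max_bin_diff is the largest absolute difference between
# train and val in the share of rows falling into any single affinity bin. With
# AFFINITY_Q_BINS = 30, each bin holds roughly 1/30 of the data (about 0.033),
# so values well below 0.01 indicate the splits track the same distribution.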
# -------------------------
# WT pooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
embs = []
for i in pbar(range(0, len(seqs), batch_size)):
batch = seqs[i:i + batch_size]
inputs = tokenizer(
batch,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
out = model(**inputs)
h = out.last_hidden_state # (B, L, H)
attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
summed = (h * attn).sum(dim=1) # (B, H)
denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
pooled = (summed / denom).detach().cpu().numpy()
embs.append(pooled)
return np.vstack(embs)
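# Shape sketch (illustrative): facebook/esm2_t33_650M_UR50D has hidden size 1280,
# so each batch of B sequences yields a (B, 1280) block and np.vstack returns
# (len(seqs), 1280). Note the mean pool above masks only padding, so CLS/EOS
# positions are included in the average (unlike wt_unpooled_one below, which
# drops them).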
# -------------------------
# WT unpooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
tok = {k: v.to(DEVICE) for k, v in tok.items()}
out = model(**tok)
h = out.last_hidden_state[0] # (L, H)
attn = tok["attention_mask"][0].bool() # (L,)
ids = tok["input_ids"][0]
keep = attn.clone()
if cls_id is not None:
keep &= (ids != cls_id)
if eos_id is not None:
keep &= (ids != eos_id)
return h[keep].detach().cpu().to(torch.float16).numpy()
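# Returns a (L_real, H) float16 array of per-residue embeddings for one sequence:
# there is no padding (padding=False) and the CLS/EOS rows are filtered via `keep`.
# Example (hypothetical, using the tokenizer/model loaded in main()):
#   emb = wt_unpooled_one("MKTAYIAKQR", wt_tok, wt_esm, cls_id, eos_id)
#   emb.shape  # -> (10, hidden_size)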
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
"""
Expects df_split to have:
- target_sequence (seq1)
- sequence (binder seq2; WT binder)
- label, affinity_class, COL_AFF, COL_WT_IPTM
Saves a dataset where each row contains BOTH:
- target_embedding (Lt,H), target_attention_mask, target_length
- binder_embedding (Lb,H), binder_attention_mask, binder_length
"""
cls_id = tokenizer.cls_token_id
eos_id = tokenizer.eos_token_id
H = model.config.hidden_size
features = Features({
"target_sequence": Value("string"),
"sequence": Value("string"),
"label": Value("float32"),
"affinity": Value("float32"),
"affinity_class": Value("string"),
"target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
"target_attention_mask": HFSequence(Value("int8")),
"target_length": Value("int64"),
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
"binder_attention_mask": HFSequence(Value("int8")),
"binder_length": Value("int64"),
COL_WT_IPTM: Value("float32"),
COL_AFF: Value("float32"),
})
def gen_rows(df: pd.DataFrame):
for r in pbar(df.itertuples(index=False), total=len(df)):
tgt = str(getattr(r, "target_sequence")).strip()
bnd = str(getattr(r, "sequence")).strip()
y = float(getattr(r, "label"))
aff = float(getattr(r, COL_AFF))
acls = str(getattr(r, "affinity_class"))
iptm = getattr(r, COL_WT_IPTM)
iptm = float(iptm) if pd.notna(iptm) else np.nan
# token embeddings for target + binder (both ESM)
t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
t_list = t_emb.tolist()
b_list = b_emb.tolist()
Lt = len(t_list)
Lb = len(b_list)
yield {
"target_sequence": tgt,
"sequence": bnd,
"label": np.float32(y),
"affinity": np.float32(aff),
"affinity_class": acls,
"target_embedding": t_list,
"target_attention_mask": [1] * Lt,
"target_length": int(Lt),
"binder_embedding": b_list,
"binder_attention_mask": [1] * Lb,
"binder_length": int(Lb),
                COL_WT_IPTM: np.float32(iptm),  # np.float32(nan) stays NaN, so no branch is needed
COL_AFF: np.float32(aff),
}
out_dir.mkdir(parents=True, exist_ok=True)
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
return ds
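# Loading sketch for a split saved by build_wt_unpooled_dataset (path shown
# relative to OUT_ROOT; use whatever out_dir was passed in):
#
#   from datasets import load_from_disk
#   ds = load_from_disk("pair_wt_wt_unpooled/train")
#   row = ds[0]
#   assert len(row["target_embedding"]) == row["target_length"]   # (Lt, H) nested lists
#   assert len(row["binder_embedding"]) == row["binder_length"]   # (Lb, H) nested lists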
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
smi_tok, smi_roformer):
"""
df_split must have:
- target_sequence (seq1)
- sequence (binder smiles string)
- label, affinity_class, COL_AFF, COL_SMI_IPTM
Saves rows with:
target_embedding (Lt,Ht) from ESM
binder_embedding (Lb,Hb) from PeptideCLM
"""
cls_id = wt_tokenizer.cls_token_id
eos_id = wt_tokenizer.eos_token_id
Ht = wt_model_unpooled.config.hidden_size
Hb = getattr(smi_roformer.config, "hidden_size", None)
if Hb is None:
Hb = getattr(smi_roformer.config, "dim", None)
if Hb is None:
raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
features = Features({
"target_sequence": Value("string"),
"sequence": Value("string"),
"label": Value("float32"),
"affinity": Value("float32"),
"affinity_class": Value("string"),
"target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
"target_attention_mask": HFSequence(Value("int8")),
"target_length": Value("int64"),
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
"binder_attention_mask": HFSequence(Value("int8")),
"binder_length": Value("int64"),
COL_SMI_IPTM: Value("float32"),
COL_AFF: Value("float32"),
})
def gen_rows(df: pd.DataFrame):
for r in pbar(df.itertuples(index=False), total=len(df)):
tgt = str(getattr(r, "target_sequence")).strip()
bnd = str(getattr(r, "sequence")).strip()
y = float(getattr(r, "label"))
aff = float(getattr(r, COL_AFF))
acls = str(getattr(r, "affinity_class"))
iptm = getattr(r, COL_SMI_IPTM)
iptm = float(iptm) if pd.notna(iptm) else np.nan
# target token embeddings (ESM)
t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
t_list = t_emb.tolist()
Lt = len(t_list)
# binder token embeddings (PeptideCLM)
_, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
[bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
)
b_emb = tok_list[0]
b_list = b_emb.tolist()
Lb = int(lengths[0])
b_mask = mask_list[0].astype(np.int8).tolist()
yield {
"target_sequence": tgt,
"sequence": bnd,
"label": np.float32(y),
"affinity": np.float32(aff),
"affinity_class": acls,
"target_embedding": t_list,
"target_attention_mask": [1] * Lt,
"target_length": int(Lt),
"binder_embedding": b_list,
"binder_attention_mask": [int(x) for x in b_mask],
"binder_length": int(Lb),
                COL_SMI_IPTM: np.float32(iptm),  # np.float32(nan) stays NaN, so no branch is needed
COL_AFF: np.float32(aff),
}
out_dir.mkdir(parents=True, exist_ok=True)
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
return ds
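# Design note: in this paired layout the target rows are ESM embeddings of width
# Ht while the binder rows are PeptideCLM embeddings of width Hb, so the two
# hidden sizes generally differ and any downstream model must project each token
# stream separately before mixing them.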
# -------------------------
# SMILES pooled + unpooled (PeptideCLM)
# -------------------------
def get_special_ids(tokenizer_obj):
cand = [
getattr(tokenizer_obj, "pad_token_id", None),
getattr(tokenizer_obj, "cls_token_id", None),
getattr(tokenizer_obj, "sep_token_id", None),
getattr(tokenizer_obj, "bos_token_id", None),
getattr(tokenizer_obj, "eos_token_id", None),
getattr(tokenizer_obj, "mask_token_id", None),
]
return sorted({x for x in cand if x is not None})
@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
tok = tokenizer_obj(
batch_sequences,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
input_ids = tok["input_ids"].to(DEVICE)
attention_mask = tok["attention_mask"].to(DEVICE)
outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
last_hidden = outputs.last_hidden_state # (B, L, H)
special_ids = get_special_ids(tokenizer_obj)
valid = attention_mask.bool()
if len(special_ids) > 0:
sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
if hasattr(torch, "isin"):
valid = valid & (~torch.isin(input_ids, sid))
else:
m = torch.zeros_like(valid)
for s in special_ids:
m |= (input_ids == s)
valid = valid & (~m)
valid_f = valid.unsqueeze(-1).float()
summed = torch.sum(last_hidden * valid_f, dim=1)
denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
pooled = (summed / denom).detach().cpu().numpy()
token_emb_list, mask_list, lengths = [], [], []
for b in range(last_hidden.shape[0]):
emb = last_hidden[b, valid[b]] # (Li, H)
token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
li = emb.shape[0]
lengths.append(int(li))
mask_list.append(np.ones((li,), dtype=np.int8))
return pooled, token_emb_list, mask_list, lengths
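# Return values, per batch of B SMILES strings:
#   pooled          -> (B, H) float32 numpy array, mean over non-special tokens
#   token_emb_list  -> list of B arrays, each (Li, H) float16, special tokens removed
#   mask_list       -> list of B all-ones int8 masks of length Li
#   lengths         -> list of B ints, the per-sequence token counts Li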
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
pooled_all = []
token_emb_all = []
mask_all = []
lengths_all = []
for i in pbar(range(0, len(seqs), batch_size)):
batch = seqs[i:i + batch_size]
pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
batch, tokenizer_obj, model_roformer, max_length
)
pooled_all.append(pooled)
token_emb_all.extend(tok_list)
mask_all.extend(m_list)
lengths_all.extend(lens)
return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
# compute target pooled embeddings once
tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
wt_train_tgt_emb = wt_pooled_embeddings(
tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
)
wt_val_tgt_emb = wt_pooled_embeddings(
tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
)
# build dict: target_sequence -> embedding
train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
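# Note: build_target_cache_from_wt_view is an optional convenience helper for
# caching pooled target embeddings keyed by target_sequence; main() below loads
# the ESM tokenizer/model directly and does not call it.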
# -------------------------
# Main
# -------------------------
def main():
log(f"[INFO] DEVICE: {DEVICE}")
OUT_ROOT.mkdir(parents=True, exist_ok=True)
with section("load csv + dedup"):
df = pd.read_csv(CSV_PATH)
for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
if c in df.columns:
df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
# Dedup
DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
print("Rows after dedup on", DEDUP_COLS, ":", len(df))
need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
missing = [c for c in need if c not in df.columns]
if missing:
raise ValueError(f"Missing required columns: {missing}")
# numeric affinity for both branches
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
# WT subset + SMILES subset separately
with section("prepare wt/smiles subsets"):
# WT: requires a canonical peptide sequence (no X) + affinity
df_wt = df.copy()
df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
# SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
df_smi = df.copy()
df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)  # a missing iptm score indicates something went wrong with that SMILES sequence
is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
# Split separately
with section("split wt and smiles separately"):
df_wt2 = make_distribution_matched_split(df_wt)
df_smi2 = make_distribution_matched_split(df_smi)
# save split tables
wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
df_wt2.to_csv(wt_split_csv, index=False)
df_smi2.to_csv(smi_split_csv, index=False)
log(f"Saved WT split meta: {wt_split_csv}")
log(f"Saved SMILES split meta: {smi_split_csv}")
verify_split_before_embedding(
df2=df_wt2,
affinity_col=COL_AFF,
split_col="split",
seq_col="wt_sequence",
iptm_col=COL_WT_IPTM,
aff_class_col="affinity_class",
aff_bins=AFFINITY_Q_BINS,
save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
verbose=False,
)
verify_split_before_embedding(
df2=df_smi2,
affinity_col=COL_AFF,
split_col="split",
seq_col="smiles_sequence",
iptm_col=COL_SMI_IPTM,
aff_class_col="affinity_class",
aff_bins=AFFINITY_Q_BINS,
save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
verbose=False,
)
# Prepare split views
def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
out = df_in.copy()
out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
# -------------------------
# Split views
# -------------------------
wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
# =========================
# TARGET pooled embeddings (ESM) — SEPARATE per branch
# =========================
with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
# ---- WT targets ----
wt_train_tgt_emb = wt_pooled_embeddings(
wt_train["target_sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
wt_val_tgt_emb = wt_pooled_embeddings(
wt_val["target_sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
# ---- SMILES targets ----
smi_train_tgt_emb = wt_pooled_embeddings(
smi_train["target_sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
smi_val_tgt_emb = wt_pooled_embeddings(
smi_val["target_sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
# =========================
# WT pooled binder embeddings (binder = WT peptide)
# =========================
with section("WT pooled binder embeddings + save"):
wt_train_emb = wt_pooled_embeddings(
wt_train["sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
wt_val_emb = wt_pooled_embeddings(
wt_val["sequence"].astype(str).str.strip().tolist(),
wt_tok, wt_esm,
batch_size=WT_BATCH,
max_length=WT_MAX_LEN,
).astype(np.float32)
wt_train_ds = Dataset.from_dict({
"target_sequence": wt_train["target_sequence"].tolist(),
"sequence": wt_train["sequence"].tolist(),
"label": wt_train["label"].astype(float).tolist(),
"target_embedding": wt_train_tgt_emb,
"embedding": wt_train_emb,
COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
"affinity_class": wt_train["affinity_class"].tolist(),
})
wt_val_ds = Dataset.from_dict({
"target_sequence": wt_val["target_sequence"].tolist(),
"sequence": wt_val["sequence"].tolist(),
"label": wt_val["label"].astype(float).tolist(),
"target_embedding": wt_val_tgt_emb,
"embedding": wt_val_emb,
COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
"affinity_class": wt_val["affinity_class"].tolist(),
})
wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
wt_pooled_dd.save_to_disk(str(wt_pooled_out))
log(f"Saved WT pooled -> {wt_pooled_out}")
# =========================
# SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
# =========================
with section("SMILES pooled binder embeddings + save"):
smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
smi_roformer = (
AutoModelForMaskedLM
.from_pretrained(SMI_MODEL_NAME)
.roformer
.to(DEVICE)
.eval()
)
smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
smi_train["sequence"].astype(str).str.strip().tolist(),
smi_tok, smi_roformer,
batch_size=SMI_BATCH,
max_length=SMI_MAX_LEN,
)
smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
smi_val["sequence"].astype(str).str.strip().tolist(),
smi_tok, smi_roformer,
batch_size=SMI_BATCH,
max_length=SMI_MAX_LEN,
)
smi_train_ds = Dataset.from_dict({
"target_sequence": smi_train["target_sequence"].tolist(),
"sequence": smi_train["sequence"].tolist(),
"label": smi_train["label"].astype(float).tolist(),
"target_embedding": smi_train_tgt_emb,
"embedding": smi_train_pooled.astype(np.float32),
COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
"affinity_class": smi_train["affinity_class"].tolist(),
})
smi_val_ds = Dataset.from_dict({
"target_sequence": smi_val["target_sequence"].tolist(),
"sequence": smi_val["sequence"].tolist(),
"label": smi_val["label"].astype(float).tolist(),
"target_embedding": smi_val_tgt_emb,
"embedding": smi_val_pooled.astype(np.float32),
COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
"affinity_class": smi_val["affinity_class"].tolist(),
})
smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
smi_pooled_dd.save_to_disk(str(smi_pooled_out))
log(f"Saved SMILES pooled -> {smi_pooled_out}")
# =========================
# WT unpooled paired (ESM target + ESM binder) + save
# =========================
with section("WT unpooled paired embeddings + save"):
wt_tok_unpooled = wt_tok # reuse tokenizer
wt_esm_unpooled = wt_esm # reuse model
wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
wt_unpooled_dd = DatasetDict({
"train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
wt_tok_unpooled, wt_esm_unpooled),
"val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
wt_tok_unpooled, wt_esm_unpooled),
})
        # Each split was already written by build_wt_unpooled_dataset; this re-save
        # adds dataset_dict.json so load_from_disk() can load the whole DatasetDict
        # from the parent directory (at the cost of writing the shards twice).
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
log(f"Saved WT unpooled -> {wt_unpooled_out}")
# =========================
# SMILES unpooled paired (ESM target + PeptideCLM binder) + save
# =========================
with section("SMILES unpooled paired embeddings + save"):
smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
smi_unpooled_dd = DatasetDict({
"train": build_smiles_unpooled_paired_dataset(
smi_train, smi_unpooled_out / "train",
wt_tok, wt_esm,
smi_tok, smi_roformer
),
"val": build_smiles_unpooled_paired_dataset(
smi_val, smi_unpooled_out / "val",
wt_tok, wt_esm,
smi_tok, smi_roformer
),
})
smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
if __name__ == "__main__":
main()