"""Build binding-affinity training datasets with ESM2 / PeptideCLM embeddings.

Pipeline (see ``main``):
  1. Load ``CSV_PATH``, strip string columns, check required columns, dedup.
  2. Build two subsets:
       * WT: binder is the canonical peptide in ``seq2`` (rows containing "X" dropped).
       * SMILES: binder is ``REACT_SMILES`` for "X"-containing peptides, else ``Fasta2SMILES``.
  3. Split each subset train/val (affinity-quantile stratified) and log a sanity report.
  4. Embed targets (ESM2) and binders (ESM2 for WT, PeptideCLM for SMILES),
     pooled and per-token, and save four ``DatasetDict``s under ``OUT_ROOT``.
"""
import os
import math
from pathlib import Path
import sys
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
from lightning.pytorch import seed_everything

seed_everything(1986)

CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")
OUT_ROOT = Path(
    "./Classifier_Weight/training_data_cleaned/binding_affinity"
)

# WT embedding model
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022
WT_BATCH = 32

# SMILES embedding model + tokenizer
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128

# Split config
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30

COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

QUIET = True
USE_TQDM = False
LOG_FILE = None


def log(msg: str):
    """Append ``msg`` to ``LOG_FILE`` (if set) and print it unless ``QUIET``."""
    if LOG_FILE is not None:
        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
        with open(LOG_FILE, "a") as f:
            f.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)


def pbar(it, **kwargs):
    """Wrap ``it`` in a tqdm progress bar when ``USE_TQDM`` is on; else return it unchanged."""
    return tqdm(it, **kwargs) if USE_TQDM else it


@contextmanager
def section(title: str):
    """Log a start/end banner around a pipeline stage.

    The end banner is only logged when the body finishes without raising.
    """
    log(f"\n=== {title} ===")
    yield
    log(f"=== done: {title} ===")


# -------------------------
# Helpers
# -------------------------
def has_uaa(seq: str) -> bool:
    """Return True if ``seq`` contains an "X" (case-insensitive)."""
    return "X" in str(seq).upper()


def affinity_to_class(a: float) -> str:
    """Map a numeric affinity to a coarse class label.

    High: >= 9 ; Moderate: [7, 9) ; Low: < 7 (NaN also falls through to "Low").
    """
    if a >= 9.0:
        return "High"
    elif a >= 7.0:
        return "Moderate"
    else:
        return "Low"


def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``split`` column ("train"/"val").

    Rows whose affinity cannot be parsed are dropped. The split is stratified on
    ``AFFINITY_Q_BINS`` affinity quantile bins, so both splits see a similar
    affinity distribution. If quantile binning fails, it falls back to the coarse
    ``affinity_class``. Within each stratum, ``floor(n * TRAIN_FRAC)`` shuffled
    rows go to train and the rest to val. The seed is fixed by ``RANDOM_SEED``.
    Adds columns ``affinity_class``, ``aff_bin`` and ``split``.
    """
    df = df.copy()
    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)

    # Best-effort fine-grained stratification; any qcut failure degrades to the
    # 3-class label rather than aborting the pipeline.
    try:
        df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
    except Exception:
        df["aff_bin"] = df["affinity_class"]
    strat_col = "aff_bin"

    rng = np.random.RandomState(RANDOM_SEED)
    df["split"] = None
    for _, g in df.groupby(strat_col, observed=True):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
        df.loc[idx[:n_train], "split"] = "train"
        df.loc[idx[n_train:], "split"] = "val"

    # Rows not covered by any stratum (e.g. NaN bin) default to train.
    df["split"] = df["split"].fillna("train")
    return df


def _summ(x):
    """Summary stats (n, mean, std, p50, p95) of the non-NaN values in ``x``."""
    x = np.asarray(x, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
    return {
        "n": int(len(x)),
        "mean": float(np.mean(x)),
        "std": float(np.std(x)),
        "p50": float(np.quantile(x, 0.50)),
        "p95": float(np.quantile(x, 0.95)),
    }


def _len_stats(seqs):
    """Summary stats (same keys as ``_summ``) of the string lengths of ``seqs``."""
    return _summ([len(str(s)) for s in seqs])


def verify_split_before_embedding(
    df2: pd.DataFrame,
    affinity_col: str,
    split_col: str,
    seq_col: str,
    iptm_col: str,
    aff_class_col: str = "affinity_class",
    aff_bins: int = 30,
    save_report_prefix: str | None = None,
    verbose: bool = False,
):
    """Log (and optionally save) a sanity report comparing the train and val splits.

    The report covers affinity mean/std, sequence-length p50/p95, and
    ``max_bin_diff``. ``max_bin_diff`` is the largest absolute difference between
    the train and val proportions of any affinity quantile bin (or of any
    ``aff_class_col`` value if quantile binning fails).

    If ``save_report_prefix`` is given, it writes ``<prefix>.stats.csv`` and
    ``<prefix>.class_prop.csv``.

    Raises:
        ValueError: if ``split_col`` is missing, holds values other than
            "train"/"val", or no affinity value survives numeric coercion.
    """
    df2 = df2.copy()
    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if split_col not in df2.columns:
        raise ValueError(f"Missing split col: {split_col}")
    if not set(df2[split_col].dropna().unique()).issubset({"train", "val"}):
        raise ValueError(f"Unexpected split values: {df2[split_col].unique()}")
    if not df2[affinity_col].notna().any():
        raise ValueError("No valid affinity values after coercion.")

    try:
        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
    except Exception:
        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)

    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
    va = df2[df2[split_col] == "val"].reset_index(drop=True)

    tr_aff = _summ(tr[affinity_col].to_numpy())
    va_aff = _summ(va[affinity_col].to_numpy())
    tr_len = _len_stats(tr[seq_col].tolist())
    va_len = _len_stats(va[seq_col].tolist())

    # Bin drift: per-split bin proportions as a (split x bin) table. Using
    # unstack + row-normalisation avoids groupby.apply, which in pandas >= 2.0
    # re-prepends the group key and broke the train/val alignment. Missing
    # splits or bins count as proportion 0.
    bin_counts = (
        df2.groupby([split_col, "_aff_bin_dbg"], observed=False)
        .size()
        .unstack(fill_value=0)
    )
    bin_props = bin_counts.div(bin_counts.sum(axis=1), axis=0)
    bin_props = bin_props.reindex(["train", "val"], fill_value=0.0)
    max_bin_diff = float((bin_props.loc["train"] - bin_props.loc["val"]).abs().max())

    msg = (
        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
        f"max_bin_diff={max_bin_diff:.4f}"
    )
    log(msg)

    if verbose and (not QUIET):
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
        print("\n[verbose] affinity_class counts:\n", class_ct)
        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))

    if save_report_prefix is not None:
        out = Path(save_report_prefix)
        out.parent.mkdir(parents=True, exist_ok=True)

        stats_df = pd.DataFrame([
            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
        ])
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()

        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)


# -------------------------
# WT pooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    """Mean-pool ``model``'s last hidden state over attended tokens, per sequence.

    The pooling uses ``attention_mask`` only. Any special tokens the tokenizer
    adds are therefore included in the mean (unlike ``wt_unpooled_one``, which
    strips CLS/EOS).

    Returns:
        np.ndarray of shape ``(len(seqs), hidden_size)``. An empty ``(0, H)``
        float32 array is returned for empty input instead of crashing in
        ``np.vstack``.
    """
    if len(seqs) == 0:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        out = model(**inputs)
        h = out.last_hidden_state                       # (B, L, H)

        attn = inputs["attention_mask"].unsqueeze(-1)   # (B, L, 1)
        summed = (h * attn).sum(dim=1)                  # (B, H)
        denom = attn.sum(dim=1).clamp(min=1e-9)         # (B, 1)
        pooled = (summed / denom).detach().cpu().numpy()
        embs.append(pooled)

    return np.vstack(embs)


# -------------------------
# WT unpooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    """Per-token last-hidden-state embeddings for one sequence, as float16 numpy.

    Positions that are padding, or whose id equals ``cls_id``/``eos_id`` (when not
    None), are dropped. The result has shape ``(L_kept, hidden_size)``.
    """
    tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    tok = {k: v.to(DEVICE) for k, v in tok.items()}

    out = model(**tok)
    h = out.last_hidden_state[0]              # (L, H)
    attn = tok["attention_mask"][0].bool()    # (L,)
    ids = tok["input_ids"][0]

    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    return h[keep].detach().cpu().to(torch.float16).numpy()


def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model, save: bool = True):
    """Build a per-token (unpooled) dataset with ESM2 embeddings for target and binder.

    Expects ``df_split`` to have:
      - ``target_sequence`` (seq1)
      - ``sequence`` (binder seq2; WT binder)
      - ``label``, ``affinity_class``, ``COL_AFF``, ``COL_WT_IPTM``

    Each row contains both:
      - ``target_embedding`` (Lt, H), ``target_attention_mask``, ``target_length``
      - ``binder_embedding`` (Lb, H), ``binder_attention_mask``, ``binder_length``

    If ``save`` is True (default), the dataset is also written to ``out_dir`` with
    1GB shards. Pass ``save=False`` when the caller will save it as part of a
    ``DatasetDict``, to avoid writing the data twice.
    """
    cls_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    H = model.config.hidden_size

    # NOTE: COL_AFF is "affinity", so the COL_AFF entry coincides with the
    # "affinity" entry (same schema/value). It is kept so the column follows
    # COL_AFF if that constant is renamed.
    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_WT_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))
            iptm = getattr(r, COL_WT_IPTM)
            iptm = np.float32(iptm) if pd.notna(iptm) else np.float32(np.nan)

            # Token embeddings for target and binder (both ESM).
            t_list = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN).tolist()  # (Lt,H)
            b_list = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN).tolist()  # (Lb,H)
            Lt = len(t_list)
            Lb = len(b_list)

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [1] * Lb,
                "binder_length": int(Lb),

                COL_WT_IPTM: iptm,
                COL_AFF: np.float32(aff),
            }

    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    if save:
        out_dir.mkdir(parents=True, exist_ok=True)
        ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds


def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer, save: bool = True):
    """Build a per-token dataset: ESM2 target embeddings + PeptideCLM binder embeddings.

    ``df_split`` must have:
      - ``target_sequence`` (seq1)
      - ``sequence`` (binder SMILES string)
      - ``label``, ``affinity_class``, ``COL_AFF``, ``COL_SMI_IPTM``

    Rows contain ``target_embedding`` (Lt, Ht) from ESM and ``binder_embedding``
    (Lb, Hb) from PeptideCLM, plus masks and lengths.

    If ``save`` is True (default), the dataset is also written to ``out_dir`` with
    1GB shards. Pass ``save=False`` when the caller saves it in a ``DatasetDict``.

    Raises:
        ValueError: if the binder hidden size cannot be read from
            ``smi_roformer.config`` (neither ``hidden_size`` nor ``dim``).
    """
    cls_id = wt_tokenizer.cls_token_id
    eos_id = wt_tokenizer.eos_token_id
    Ht = wt_model_unpooled.config.hidden_size

    Hb = getattr(smi_roformer.config, "hidden_size", None)
    if Hb is None:
        Hb = getattr(smi_roformer.config, "dim", None)
    if Hb is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    # NOTE: COL_AFF is "affinity"; see build_wt_unpooled_dataset for why both
    # entries are listed.
    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_SMI_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))
            iptm = getattr(r, COL_SMI_IPTM)
            iptm = np.float32(iptm) if pd.notna(iptm) else np.float32(np.nan)

            # Target token embeddings (ESM).
            t_list = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN).tolist()
            Lt = len(t_list)

            # Binder token embeddings (PeptideCLM), one sequence per call.
            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            b_list = tok_list[0].tolist()
            Lb = int(lengths[0])
            b_mask = mask_list[0].astype(np.int8).tolist()

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [int(x) for x in b_mask],
                "binder_length": int(Lb),

                COL_SMI_IPTM: iptm,
                COL_AFF: np.float32(aff),
            }

    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    if save:
        out_dir.mkdir(parents=True, exist_ok=True)
        ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds


# -------------------------
# SMILES pooled + unpooled (PeptideCLM)
# -------------------------
def get_special_ids(tokenizer_obj):
    """Sorted, de-duplicated ids of whichever pad/cls/sep/bos/eos/mask tokens the tokenizer defines."""
    cand = [
        getattr(tokenizer_obj, "pad_token_id", None),
        getattr(tokenizer_obj, "cls_token_id", None),
        getattr(tokenizer_obj, "sep_token_id", None),
        getattr(tokenizer_obj, "bos_token_id", None),
        getattr(tokenizer_obj, "eos_token_id", None),
        getattr(tokenizer_obj, "mask_token_id", None),
    ]
    return sorted({x for x in cand if x is not None})


@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    """Embed a batch of SMILES; return pooled and per-token embeddings.

    Tokens are kept when attended and not one of ``get_special_ids(tokenizer_obj)``.

    Returns:
        (pooled, token_emb_list, mask_list, lengths)
          - ``pooled``: (B, H) numpy mean over kept tokens.
          - ``token_emb_list``: B float16 arrays of shape (Li, H).
          - ``mask_list``: B int8 arrays of ones, length Li.
          - ``lengths``: B ints (Li).
    """
    tok = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = tok["input_ids"].to(DEVICE)
    attention_mask = tok["attention_mask"].to(DEVICE)

    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state  # (B, L, H)

    special_ids = get_special_ids(tokenizer_obj)
    valid = attention_mask.bool()
    if len(special_ids) > 0:
        sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            valid = valid & (~torch.isin(input_ids, sid))
        else:
            # Fallback for torch versions without torch.isin.
            m = torch.zeros_like(valid)
            for s in special_ids:
                m |= (input_ids == s)
            valid = valid & (~m)

    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    token_emb_list, mask_list, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]  # (Li, H)
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        li = emb.shape[0]
        lengths.append(int(li))
        mask_list.append(np.ones((li,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths


def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    """Run ``smiles_embed_batch_return_both`` over ``seqs`` in batches and concatenate.

    Returns ``(pooled (N, H), token_embs, masks, lengths)``. For empty input, the
    pooled array is an empty ``(0, H)`` float32 array (``np.vstack([])`` would raise).
    """
    if len(seqs) == 0:
        cfg = model_roformer.config
        hidden = getattr(cfg, "hidden_size", None) or getattr(cfg, "dim", None) or 0
        return np.empty((0, hidden), dtype=np.float32), [], [], []

    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            batch, tokenizer_obj, model_roformer, max_length
        )
        pooled_all.append(pooled)
        token_emb_all.extend(tok_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)
    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all


def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    """Load ESM2 and compute pooled target embeddings for the train/val WT views.

    Returns ``(tokenizer, model, train_emb, val_emb, train_map, val_map)``, where the
    maps go from ``target_sequence`` string to its pooled embedding. Duplicate
    sequences keep the last occurrence's embedding.
    Not called by ``main`` in this file.
    """
    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    # Compute target pooled embeddings once.
    tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
    tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()

    wt_train_tgt_emb = wt_pooled_embeddings(
        tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )
    wt_val_tgt_emb = wt_pooled_embeddings(
        tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )

    # Build dict: target_sequence -> embedding.
    train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
    val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}

    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map


# -------------------------
# Main
# -------------------------
def main():
    """Run the full load -> filter -> split -> embed -> save pipeline."""
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        # Validate schema before dedup, which itself needs several of these columns.
        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
        missing = [c for c in need if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Dedup
        dedup_cols = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=dedup_cols).reset_index(drop=True)
        log(f"Rows after dedup on {dedup_cols} : {len(df)}")

        # Numeric affinity for both branches.
        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    # WT subset + SMILES subset separately.
    with section("prepare wt/smiles subsets"):
        # WT: requires a canonical peptide sequence (no X) + affinity.
        # Drop missing binders BEFORE astype(str), which would turn NaN into "nan".
        df_wt = df[df[COL_SEQ2].notna()].copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"] != ""]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES).
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        # An empty iptm means something went wrong with that row's SMILES.
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)

        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        # astype(str) renders missing values as "nan"/"None"; drop them.
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    # Split separately.
    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        # Save split tables.
        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    # Prepare split views.
    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        """Project ``df_in`` to the columns the embedding stages consume.

        Rows with a missing or empty target or binder are dropped. astype(str)
        would otherwise turn NaN into the literal "nan" and embed it.
        """
        out = df_in[df_in[COL_SEQ1].notna() & df_in[binder_seq_col].notna()].copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()  # binder
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out.dropna(subset=["target_sequence", "sequence", "label"])
        out = out[(out["target_sequence"] != "") & (out["sequence"] != "")].reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    # -------------------------
    # Split views
    # -------------------------
    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    # =========================
    # TARGET pooled embeddings (ESM) — SEPARATE per branch
    # =========================
    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        # ---- WT targets ----
        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        # ---- SMILES targets ----
        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    # =========================
    # WT pooled binder embeddings (binder = WT peptide)
    # =========================
    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    # =========================
    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
    # =========================
    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    # =========================
    # WT unpooled paired (ESM target + ESM binder) + save
    # =========================
    # Splits are built without saving and then written once via the DatasetDict
    # (same 1GB shards). Saving per split first duplicated the write and left
    # stale shards next to the DatasetDict's own files.
    with section("WT unpooled paired embeddings + save"):
        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train", wt_tok, wt_esm, save=False),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val", wt_tok, wt_esm, save=False),
        })
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out), max_shard_size="1GB")
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    # =========================
    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
    # =========================
    with section("SMILES unpooled paired embeddings + save"):
        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer,
                save=False,
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer,
                save=False,
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out), max_shard_size="1GB")
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()