import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from lightning.pytorch import seed_everything
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM
from datasets import Dataset, DatasetDict
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

seed_everything(1986)

# ======================
# Load and clean SMILES
# ======================
df = pd.read_csv("caco2.csv")

mols = []
canon = []
keep_rows = []
bad = 0
for i, smi in enumerate(df["SMILES"].astype(str)):
    m = Chem.MolFromSmiles(smi)
    if m is None:
        bad += 1
        continue
    smi_can = Chem.MolToSmiles(m, canonical=True, isomericSmiles=True)
    mols.append(m)
    canon.append(smi_can)
    keep_rows.append(i)

df = df.iloc[keep_rows].reset_index(drop=True)
df["SMILES_CANON"] = canon
print(f"Invalid SMILES dropped: {bad} / {len(df) + bad}")

# Drop exact duplicate molecules (same canonical SMILES)
dup_mask = df.duplicated(subset=["SMILES_CANON"], keep="first")
df = df.loc[~dup_mask].reset_index(drop=True)
mols = [m for m, isdup in zip(mols, dup_mask) if not isdup]

# Fingerprints
morgan = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
fps = [morgan.GetFingerprint(m) for m in mols]

# Cluster by similarity threshold (Butina takes the lower-triangle distance matrix)
sim_thresh = 0.6
dist_thresh = 1.0 - sim_thresh
dists = []
n = len(fps)
for i in range(1, n):
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
    dists.extend([1.0 - x for x in sims])

clusters = Butina.ClusterData(dists, nPts=n, distThresh=dist_thresh, isDistData=True)
cluster_ids = np.empty(n, dtype=int)
for cid, idxs in enumerate(clusters):
    for idx in idxs:
        cluster_ids[idx] = cid
df["cluster_id"] = cluster_ids

# Split by clusters: greedily assign whole clusters to train until ~80% of molecules are covered
train_fraction = 0.8
rng = np.random.default_rng(1986)  # seed explicitly; seed_everything does not seed a fresh Generator
unique_clusters = df["cluster_id"].unique()
rng.shuffle(unique_clusters)
train_target = int(train_fraction * len(df))
train_clusters = set()
count = 0
for cid in unique_clusters:
    csize = (df["cluster_id"] == cid).sum()
    if count + csize <= train_target:
        train_clusters.add(cid)
        count += csize

df["split"] = np.where(df["cluster_id"].isin(train_clusters), "train", "val")
df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
df.to_csv("caco2_meta_with_split.csv", index=False)
print(df["split"].value_counts())

# ======================
# Config
# ======================
MAX_LENGTH = 768
BATCH_SIZE = 128
TRAIN_CSV = "caco2_train.csv"
VAL_CSV = "caco2_val.csv"
SMILES_COL = "SMILES"
LABEL_COL = "Caco2"
OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"

# GPU device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ======================
# Load tokenizer + model
# ======================
print("Loading tokenizer and model...")
tokenizer = SMILES_SPE_Tokenizer(
    "./Classifier_Weight/tokenizer/new_vocab.txt",
    "./Classifier_Weight/tokenizer/new_splits.txt",
)
# Use the bare RoFormer encoder (without the MLM head) as the embedding model
embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer
embedding_model.to(device)
embedding_model.eval()

HIDDEN_KEY = "last_hidden_state"


def get_special_ids(tokenizer):
    """Collect whatever special token ids the tokenizer exposes."""
    cand = [
        getattr(tokenizer, "pad_token_id", None),
        getattr(tokenizer, "cls_token_id", None),
        getattr(tokenizer, "sep_token_id", None),
        getattr(tokenizer, "bos_token_id", None),
        getattr(tokenizer, "eos_token_id", None),
        getattr(tokenizer, "mask_token_id", None),
    ]
    special_ids = sorted({x for x in cand if x is not None})
    if len(special_ids) == 0:
        print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
    return special_ids


SPECIAL_IDS = get_special_ids(tokenizer)
SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long) if len(SPECIAL_IDS) else None


@torch.no_grad()
def embed_batch_return_both(batch_sequences, max_length, device):
    tok = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        max_length=max_length,
        truncation=True,
    )
    input_ids = tok["input_ids"].to(device)             # (B, L)
    attention_mask = tok["attention_mask"].to(device)   # (B, L)

    outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state             # (B, L, H)

    # Valid positions: real tokens only (no padding, no special tokens)
    valid = attention_mask.bool()
    if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
        valid = valid & (~torch.isin(input_ids, SPECIAL_IDS_T))

    # --- pooled embeddings: mean over valid tokens (exclude specials) ---
    valid_f = valid.unsqueeze(-1).float()                # (B, L, 1)
    summed = torch.sum(last_hidden * valid_f, dim=1)     # (B, H)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)    # (B, 1)
    pooled = (summed / denom).detach().cpu().numpy()     # (B, H), float32

    # --- unpooled per-example token embeddings (exclude specials) ---
    token_emb_list = []
    mask_list = []
    lengths = []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]                                       # (L_i, H)
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())  # float16
        L_i = emb.shape[0]
        lengths.append(int(L_i))
        mask_list.append(np.ones((L_i,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths


def generate_embeddings_batched_both(sequences, batch_size, max_length):
    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []
    for i in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
        batch = sequences[i:i + batch_size]
        pooled, token_list, m_list, lens = embed_batch_return_both(batch, max_length, device)
        pooled_all.append(pooled)
        token_emb_all.extend(token_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)
    pooled_all = np.vstack(pooled_all)  # (N, H)
    return pooled_all, token_emb_all, mask_all, lengths_all


def make_split_datasets(csv_path, split_name):
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
    df["sequence"] = df[SMILES_COL].astype(str)
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
    df = df.loc[~labels.isna()].reset_index(drop=True)

    sequences = df["sequence"].tolist()
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce").tolist()

    # pooled_embs: (N, H); token_emb_list: list of (L_i, H); mask_list: list of (L_i,); lengths: list[int]
    pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
        sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
    )

    pooled_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": pooled_embs,        # (N, H)
    })
    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": token_emb_list,     # each (L_i, H) float16
        "attention_mask": mask_list,     # each (L_i,) int8 ones
        "length": lengths,
    })
    return pooled_ds, full_ds


train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
val_pooled, val_full = make_split_datasets(VAL_CSV, "val")

ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
ds_full = DatasetDict({"train": train_full, "val": val_full})

ds_pooled.save_to_disk(OUT_PATH)
ds_full.save_to_disk(OUT_PATH + "_unpooled")
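
# ----------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the original
# pipeline): reload the datasets that were just written and print their
# structure. Assumes the OUT_PATH locations above and that both splits
# contain at least one example.
# ----------------------------------------------------------------------
from datasets import load_from_disk

reloaded_pooled = load_from_disk(OUT_PATH)
reloaded_full = load_from_disk(OUT_PATH + "_unpooled")
print(reloaded_pooled)                                # splits and column names
print(len(reloaded_pooled["train"][0]["embedding"]))  # hidden size H of the pooled vector
print(reloaded_full["train"][0]["length"])            # non-special token count for one example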