import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pandas as pd

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina

from lightning.pytorch import seed_everything
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM
from datasets import Dataset, DatasetDict

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

seed_everything(1986)

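# ---------------------------------------------------------------------------
# Step 1: Clean the Caco-2 data and build a Butina cluster-based train/val split
# ---------------------------------------------------------------------------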
df = pd.read_csv("caco2.csv")

# Parse every SMILES with RDKit, keep only valid molecules, and store the
# canonical isomeric SMILES for deduplication.
mols = []
canon = []
keep_rows = []
bad = 0

for i, smi in enumerate(df["SMILES"].astype(str)):
    m = Chem.MolFromSmiles(smi)
    if m is None:
        bad += 1
        continue
    smi_can = Chem.MolToSmiles(m, canonical=True, isomericSmiles=True)
    mols.append(m)
    canon.append(smi_can)
    keep_rows.append(i)

df = df.iloc[keep_rows].reset_index(drop=True)
df["SMILES_CANON"] = canon

print(f"Invalid SMILES dropped: {bad} / {len(df) + bad}")

# Drop duplicate molecules by canonical SMILES, keeping the first occurrence,
# and keep the RDKit mol list aligned with the filtered dataframe.
dup_mask = df.duplicated(subset=["SMILES_CANON"], keep="first")
df = df.loc[~dup_mask].reset_index(drop=True)
mols = [m for m, isdup in zip(mols, dup_mask) if not isdup]

# Morgan fingerprints (radius 2, 2048 bits, chirality-aware) for similarity clustering.
morgan = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
fps = [morgan.GetFingerprint(m) for m in mols]

# Butina clustering on Tanimoto distance: molecules with similarity >= 0.6
# (distance <= 0.4) end up in the same cluster.
sim_thresh = 0.6
dist_thresh = 1.0 - sim_thresh

# Flattened lower-triangle distance matrix, as expected by Butina.ClusterData.
dists = []
n = len(fps)
for i in range(1, n):
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
    dists.extend([1.0 - x for x in sims])

clusters = Butina.ClusterData(dists, nPts=n, distThresh=dist_thresh, isDistData=True)

# Map each molecule index back to its cluster id.
cluster_ids = np.empty(n, dtype=int)
for cid, idxs in enumerate(clusters):
    for idx in idxs:
        cluster_ids[idx] = cid

df["cluster_id"] = cluster_ids

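# Optional diagnostic (a sketch, not required for the split): summarize the
# cluster size distribution to sanity-check the similarity threshold.
cluster_sizes = np.array([len(c) for c in clusters])
print(
    f"Clusters: {len(clusters)}, largest: {cluster_sizes.max()}, "
    f"singletons: {(cluster_sizes == 1).sum()}"
)
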
# Greedy cluster-level split: whole clusters are assigned to train until the
# target fraction is reached, so similar molecules never straddle the split.
train_fraction = 0.8
# Seed the generator explicitly; seed_everything does not seed np.random.default_rng.
rng = np.random.default_rng(1986)

unique_clusters = df["cluster_id"].unique()
rng.shuffle(unique_clusters)

train_target = int(train_fraction * len(df))
train_clusters = set()
count = 0
for cid in unique_clusters:
    csize = (df["cluster_id"] == cid).sum()
    if count + csize <= train_target:
        train_clusters.add(cid)
        count += csize

df["split"] = np.where(df["cluster_id"].isin(train_clusters), "train", "val")

df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
df.to_csv("caco2_meta_with_split.csv", index=False)

print(df["split"].value_counts())

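# ---------------------------------------------------------------------------
# Step 2: Embed each split with PeptideCLM and save HuggingFace datasets
# ---------------------------------------------------------------------------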
MAX_LENGTH = 768
BATCH_SIZE = 128

TRAIN_CSV = "caco2_train.csv"
VAL_CSV = "caco2_val.csv"

SMILES_COL = "SMILES"
LABEL_COL = "Caco2"

OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading tokenizer and model...") |
|
|
tokenizer = SMILES_SPE_Tokenizer( |
|
|
"./Classifier_Weight/tokenizer/new_vocab.txt", |
|
|
"./Classifier_Weight/tokenizer/new_splits.txt", |
|
|
) |
|
|
|
|
|
embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer |
|
|
embedding_model.to(device) |
|
|
embedding_model.eval() |
|
|
|
|
|
HIDDEN_KEY = "last_hidden_state" |
|
|
|
|
|
def get_special_ids(tokenizer):
    """Collect special token ids so they can be excluded from mean pooling."""
    cand = [
        getattr(tokenizer, "pad_token_id", None),
        getattr(tokenizer, "cls_token_id", None),
        getattr(tokenizer, "sep_token_id", None),
        getattr(tokenizer, "bos_token_id", None),
        getattr(tokenizer, "eos_token_id", None),
        getattr(tokenizer, "mask_token_id", None),
    ]
    special_ids = sorted({x for x in cand if x is not None})
    if len(special_ids) == 0:
        print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
    return special_ids


SPECIAL_IDS = get_special_ids(tokenizer)
SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long) if len(SPECIAL_IDS) else None

@torch.no_grad()
def embed_batch_return_both(batch_sequences, max_length, device):
    """Embed one batch, returning both mean-pooled and per-token embeddings."""
    tok = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        max_length=max_length,
        truncation=True,
    )
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state

    # Valid positions: real tokens only (exclude padding and special tokens).
    valid = attention_mask.bool()
    if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
        valid = valid & (~torch.isin(input_ids, SPECIAL_IDS_T))

    # Mean-pool the hidden states over the valid positions.
    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    # Also keep the unpooled per-token embeddings (float16 to save disk space).
    token_emb_list = []
    mask_list = []
    lengths = []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        L_i = emb.shape[0]
        lengths.append(int(L_i))
        mask_list.append(np.ones((L_i,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths


def generate_embeddings_batched_both(sequences, batch_size, max_length):
    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []

    for i in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
        batch = sequences[i:i + batch_size]
        pooled, token_list, m_list, lens = embed_batch_return_both(batch, max_length, device)
        pooled_all.append(pooled)
        token_emb_all.extend(token_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)

    pooled_all = np.vstack(pooled_all)
    return pooled_all, token_emb_all, mask_all, lengths_all


def make_split_datasets(csv_path, split_name):
    """Build pooled and unpooled HuggingFace datasets for one split CSV."""
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
    df["sequence"] = df[SMILES_COL].astype(str)

    # Drop rows whose label cannot be parsed as a number.
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
    df = df.loc[~labels.isna()].reset_index(drop=True)
    sequences = df["sequence"].tolist()
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce").tolist()

    pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
        sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
    )

    # Pooled dataset: one fixed-size embedding vector per molecule.
    pooled_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": pooled_embs,
    })

    # Unpooled dataset: variable-length per-token embeddings with masks and lengths.
    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": token_emb_list,
        "attention_mask": mask_list,
        "length": lengths,
    })

    return pooled_ds, full_ds


train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
val_pooled, val_full = make_split_datasets(VAL_CSV, "val")

ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
ds_full = DatasetDict({"train": train_full, "val": val_full})

ds_pooled.save_to_disk(OUT_PATH)
ds_full.save_to_disk(OUT_PATH + "_unpooled")

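# Optional sanity check (a sketch, not part of the pipeline): reload the pooled
# dataset that was just written and confirm its structure and embedding size.
from datasets import load_from_disk

reloaded = load_from_disk(OUT_PATH)
print(reloaded)
print("Pooled embedding dim:", len(reloaded["train"][0]["embedding"]))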