# non fouling as example
import os

import numpy as np
import pandas as pd
import torch
import tqdm
from datasets import Dataset, DatasetDict, Features, Sequence, Value
from lightning.pytorch import seed_everything
from transformers import AutoTokenizer, EsmModel

seed_everything(1986)

# -------------------------
# Process from source
# -------------------------
# Canonical amino-acid alphabet (index 0 = padding); kept for reference.
m1 = [
    '[PAD]', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H',
    'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'
]
# Token -> index mapping used by the source .npz files.
m2 = dict(zip(
    ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'L',
     'A', 'G', 'V', 'E', 'S', 'I', 'K', 'R', 'D', 'T', 'P', 'N',
     'Q', 'F', 'Y', 'M', 'H', 'C', 'W', 'X', 'U', 'B', 'Z', 'O'],
    range(30)
))

# Create reverse mapping (index -> token) for decoding
reverse_m2 = {v: k for k, v in m2.items()}

sequences = []
labels = []

# Load and process positive sequences (label 1)
print("Processing positive sequences...")
with np.load('nf-positive.npz') as pos:
    pos_data = pos['arr_0']
    for seq in pos_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(1)

# Load and process negative sequences (label 0)
print("Processing negative sequences...")
with np.load('nf-negative.npz') as neg:
    neg_data = neg['arr_0']
    for seq in neg_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(0)

# Build a DataFrame and add stable IDs
ids = [f"seq_{i:06d}" for i in range(len(sequences))]
df = pd.DataFrame({
    "id": ids,
    "sequence": sequences,
    "label": labels,
})

print("Before dedup:", len(df))
df = (
    df
    .drop_duplicates(subset=["sequence"])  # keep first occurrence
    .reset_index(drop=True)
)
print("After dedup:", len(df))

# Save to CSV
df.to_csv("nf_all.csv", index=False)
print("Saved nf_all.csv")

# Save as FASTA for MMseqs
with open("nf_all.fasta", "w") as f:
    for seq_id, seq in zip(df["id"], df["sequence"]):
        f.write(f">{seq_id}\n{seq}\n")
print("Saved nf_all.fasta")

# -------------------------
# RUN MMSEQS IN TERMINAL
# -------------------------
"""
mkdir -p mmseqs_tmp
mmseqs createdb nf_all.fasta nfDB
mmseqs cluster nfDB nfDB_clu mmseqs_tmp \
    --min-seq-id 0.3 -c 0.8 --cov-mode 0
mmseqs createtsv nfDB nfDB nfDB_clu clusters-nf.tsv
"""
# -------------------------
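
# Optional sanity check (not part of the original pipeline): after running the
# MMseqs commands above, confirm the clustering produced a sensible number of
# clusters before splitting on them. This assumes the two-column rep/member TSV
# that `mmseqs createtsv` writes.
clu = pd.read_csv("clusters-nf.tsv", sep="\t", header=None, names=["rep_id", "member_id"])
print("Clusters:", clu["rep_id"].nunique(), "| clustered members:", len(clu))
print("Largest cluster size:", clu.groupby("rep_id").size().max())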
# -------------------------
# Split based on clusters
# -------------------------
train_fraction = 0.8
csv_path = "nf_all.csv"
clusters_tsv = "clusters-nf.tsv"
rng = np.random.default_rng(1986)  # reuse the global seed so the split is reproducible

df = pd.read_csv(csv_path)  # must contain: id, sequence, label

# Map id -> index
id_to_index = {sid: i for i, sid in enumerate(df["id"])}

# Read MMseqs clusters
cluster_map = {}  # member_id -> cluster_id (rep_id)
with open(clusters_tsv) as f:
    for line in f:
        if not line.strip():
            continue
        rep_id, member_id = line.strip().split('\t')
        cluster_map[member_id] = rep_id

# Handle singleton sequences (not clustered)
for sid in df["id"]:
    if sid not in cluster_map:
        cluster_map[sid] = sid

# Invert to cluster_id -> dataset indices
cluster_to_indices = {}
for sid, cid in cluster_map.items():
    idx = id_to_index[sid]
    cluster_to_indices.setdefault(cid, []).append(idx)

# Shuffle clusters
cluster_ids = list(cluster_to_indices.keys())
rng.shuffle(cluster_ids)

# Assign whole clusters to splits until the train target is reached
total_n = len(df)
train_target = int(train_fraction * total_n)
train_indices = []
val_indices = []
current_train = 0
for cid in cluster_ids:
    indices = cluster_to_indices[cid]
    if current_train + len(indices) <= train_target:
        train_indices.extend(indices)
        current_train += len(indices)
    else:
        val_indices.extend(indices)

# Create split column
split = np.full(total_n, "val", dtype=object)
split[train_indices] = "train"

# ===== Master CSV with split =====
df_with_split = df.copy()
df_with_split["split"] = split
df_with_split.to_csv("nf_meta_with_split.csv", index=False)

# ===== Other CSVs =====
df_train = df_with_split[df_with_split["split"] == "train"].reset_index(drop=True)
df_val = df_with_split[df_with_split["split"] == "val"].reset_index(drop=True)
df_train.to_csv("nf_train.csv", index=False)
df_val.to_csv("nf_val.csv", index=False)

# ===== Quick sanity output =====
print("Split counts:")
print(df_with_split["split"].value_counts())
print()
print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")
print("Wrote:")
print(" - nf_meta_with_split.csv")
print(" - nf_train.csv")
print(" - nf_val.csv")

# -------------------------
# Mean-pooled ESM-2 embeddings
# -------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = model.to(device)
model.eval()


def compute_embeddings(sequences, batch_size=32):
    """Return numpy array of shape (N, hidden_dim)."""
    embeddings = []
    for i in tqdm.trange(0, len(sequences), batch_size):
        batch_sequences = sequences[i:i + batch_size]
        inputs = tokenizer(
            batch_sequences,
            padding=True,
            max_length=1022,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        last_hidden_states = outputs.last_hidden_state  # (B, L, H)

        # Mean-pool over non-padding positions
        attention_mask = inputs["attention_mask"].unsqueeze(-1)
        masked_hidden_states = last_hidden_states * attention_mask
        sum_hidden_states = masked_hidden_states.sum(dim=1)
        seq_lengths = attention_mask.sum(dim=1)
        batch_embeddings = sum_hidden_states / seq_lengths  # (B, H)
        embeddings.append(batch_embeddings.cpu())
    return torch.cat(embeddings, dim=0).numpy()


def create_and_save_datasets():
    # Load sequences + labels + splits
    meta = pd.read_csv(meta_path)
    sequences = meta["sequence"].tolist()
    labels = meta["label"].tolist()
    splits = meta["split"].tolist()

    print(f"Total sequences: {len(sequences)}")
    print("Split counts:", pd.Series(splits).value_counts().to_dict())

    print("Computing ESM embeddings...")
    embeddings = compute_embeddings(sequences)  # shape (N, H)

    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "embedding": embeddings,
        "label": labels,
        "split": splits,
    })

    # Split into train / val, then drop the 'split' column
    train_ds = full_ds.filter(lambda x: x["split"] == "train")
    val_ds = full_ds.filter(lambda x: x["split"] == "val")
    train_ds = train_ds.remove_columns("split")
    val_ds = val_ds.remove_columns("split")

    ds_dict = DatasetDict({
        "train": train_ds,
        "val": val_ds,
    })
    ds_dict.save_to_disk(save_path)

    print(f"Saved DatasetDict with train/val to: {save_path}")
    print("Train size:", len(ds_dict["train"]))
    print("Val size:", len(ds_dict["val"]))
    return ds_dict


ds = create_and_save_datasets()

ex = ds["train"][0]
print("\nExample from train:")
print("Sequence:", ex["sequence"])
print("Embedding shape:", np.array(ex["embedding"]).shape)
print("Label:", ex["label"])

torch.cuda.empty_cache()
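
# Optional round-trip check (a minimal sketch, not part of the original pipeline):
# reload the pooled DatasetDict from disk and confirm the embedding width and
# labels survive serialization. `load_from_disk` comes from the same `datasets`
# package already imported above.
from datasets import load_from_disk

reloaded = load_from_disk(save_path)
first = reloaded["train"][0]
print("Reloaded splits:", {name: len(split_ds) for name, split_ds in reloaded.items()})
print("Embedding dim:", len(first["embedding"]))  # expected 1280 for esm2_t33_650M
print("Label:", first["label"])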
"./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings_unpooled" tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D", add_pooling_layer=False).to(device).eval() cls_id = tokenizer.cls_token_id eos_id = tokenizer.eos_token_id @torch.no_grad() def embed_one(seq, max_length=1022): inputs = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} out = model(**inputs) h = out.last_hidden_state[0] # (L, H) attn = inputs["attention_mask"][0].bool() # (L,) ids = inputs["input_ids"][0] keep = attn.clone() if cls_id is not None: keep &= (ids != cls_id) if eos_id is not None: keep &= (ids != eos_id) hb = h[keep].detach().cpu().to(torch.float16).numpy() # (Li, H) return hb H = 1280 features = Features({ "sequence": Value("string"), "label": Value("int64"), "embedding": Sequence(Sequence(Value("float16"), length=H)), # (Li, H) as nested lists "attention_mask": Sequence(Value("int8")), # (Li,) "length": Value("int64"), }) def make_generator(df): for seq, lab in tqdm.tqdm(zip(df["sequence"].tolist(), df["label"].astype(int).tolist()), total=len(df)): emb = embed_one(seq) # (Li, H) float16 emb_list = emb.tolist() li = len(emb_list) yield { "sequence": seq, "label": int(lab), "embedding": emb_list, "attention_mask": [1] * li, "length": li, } def build_and_save_split(df, out_dir): ds = Dataset.from_generator(make_generator, gen_kwargs={"df": df}, features=features) # small batches when writing prevents the 2GB offset overflow ds.save_to_disk(out_dir, max_shard_size="1GB") return ds meta = pd.read_csv(meta_path) train_df = meta[meta["split"] == "train"].reset_index(drop=True) val_df = meta[meta["split"] == "val"].reset_index(drop=True) train_dir = os.path.join(save_path, "train") val_dir = os.path.join(save_path, "val") os.makedirs(save_path, exist_ok=True) train_ds = build_and_save_split(train_df, train_dir) val_ds = build_and_save_split(val_df, val_dir) ds_dict = DatasetDict({"train": train_ds, "val": val_ds}) ds_dict.save_to_disk(save_path) print(ds_dict)