import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm  # imported as a module: the code below uses tqdm.trange and tqdm.tqdm
from datasets import Dataset, DatasetDict, Features, Value, Sequence
from lightning.pytorch import seed_everything
from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmModel

seed_everything(1986)

# Token vocabularies used when the sequences in the .npz files were tokenized.
# m1: 21 tokens (PAD + the 20 canonical amino acids); kept for reference.
m1 = [
    '[PAD]', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H',
    'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V',
]

# m2: 30 tokens (special tokens, amino acids, and the ambiguity codes
# X/U/B/Z/O), mapped token -> integer id.
m2 = dict(zip(
    ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'L',
     'A', 'G', 'V', 'E', 'S', 'I', 'K', 'R', 'D', 'T', 'P', 'N',
     'Q', 'F', 'Y', 'M', 'H', 'C', 'W', 'X', 'U', 'B', 'Z', 'O'],
    range(30),
))

# Inverse map: integer id -> token, used to decode ids back into strings.
reverse_m2 = {v: k for k, v in m2.items()}

sequences = []
labels = []

# Decode the tokenized sequences in the .npz archives back into
# amino-acid strings, skipping padding (token id 0).
print("Processing positive sequences...")
with np.load('nf-positive.npz') as pos:
    pos_data = pos['arr_0']
    for seq in pos_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(1)

print("Processing negative sequences...")
with np.load('nf-negative.npz') as neg:
    neg_data = neg['arr_0']
    for seq in neg_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(0)

# Assemble the metadata table and drop duplicate sequences.
ids = [f"seq_{i:06d}" for i in range(len(sequences))]
df = pd.DataFrame({
    "id": ids,
    "sequence": sequences,
    "label": labels,
})
print("Before dedup:", len(df))

df = (
    df
    .drop_duplicates(subset=["sequence"])
    .reset_index(drop=True)
)

print("After dedup:", len(df))

df.to_csv("nf_all.csv", index=False)
print("Saved nf_all.csv")

# Write a FASTA file for MMseqs2 clustering.
with open("nf_all.fasta", "w") as f:
    for seq_id, seq in zip(df["id"], df["sequence"]):
        f.write(f">{seq_id}\n{seq}\n")

print("Saved nf_all.fasta")

# Cluster at 30% sequence identity with 80% bidirectional coverage.
# These MMseqs2 commands are run from the shell, not from this script:
"""
mkdir -p mmseqs_tmp

mmseqs createdb nf_all.fasta nfDB

mmseqs cluster nfDB nfDB_clu mmseqs_tmp \
    --min-seq-id 0.3 -c 0.8 --cov-mode 0

mmseqs createtsv nfDB nfDB nfDB_clu clusters-nf.tsv
"""

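# Optional convenience: the same pipeline can be driven from Python via
# subprocess. This is a minimal sketch assuming the `mmseqs` binary is on
# PATH; the shell commands above remain the reference. Not called by default.
def run_mmseqs_clustering(fasta="nf_all.fasta", db="nfDB", out_tsv="clusters-nf.tsv", tmp_dir="mmseqs_tmp"):
    import subprocess
    os.makedirs(tmp_dir, exist_ok=True)
    subprocess.run(["mmseqs", "createdb", fasta, db], check=True)
    subprocess.run(["mmseqs", "cluster", db, f"{db}_clu", tmp_dir,
                    "--min-seq-id", "0.3", "-c", "0.8", "--cov-mode", "0"], check=True)
    subprocess.run(["mmseqs", "createtsv", db, db, f"{db}_clu", out_tsv], check=True)
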
# Cluster-aware train/val split: whole clusters go to a single split so
# that near-identical sequences cannot leak across the boundary.
train_fraction = 0.8
csv_path = "nf_all.csv"
clusters_tsv = "clusters-nf.tsv"
rng = np.random.default_rng(1986)  # seeded explicitly; seed_everything does not seed default_rng

df = pd.read_csv(csv_path)

id_to_index = {sid: i for i, sid in enumerate(df["id"])}

# Parse the MMseqs2 TSV (one "representative<TAB>member" pair per line)
# into a member -> representative map.
cluster_map = {}
with open(clusters_tsv) as f:
    for line in f:
        if not line.strip():
            continue
        rep_id, member_id = line.strip().split('\t')
        cluster_map[member_id] = rep_id

# Sequences absent from the TSV become their own singleton clusters.
for sid in df["id"]:
    if sid not in cluster_map:
        cluster_map[sid] = sid

# Group dataframe row indices by cluster representative.
cluster_to_indices = {}
for sid, cid in cluster_map.items():
    idx = id_to_index[sid]
    cluster_to_indices.setdefault(cid, []).append(idx)

cluster_ids = list(cluster_to_indices.keys())
rng.shuffle(cluster_ids)

total_n = len(df)
train_target = int(train_fraction * total_n)

train_indices = []
val_indices = []
current_train = 0

# Greedily assign whole clusters to train until the target size is reached;
# everything else goes to validation.
for cid in cluster_ids:
    indices = cluster_to_indices[cid]
    if current_train + len(indices) <= train_target:
        train_indices.extend(indices)
        current_train += len(indices)
    else:
        val_indices.extend(indices)

split = np.full(total_n, "val", dtype=object)
split[train_indices] = "train"

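# Sanity check (added sketch): confirm the split is cluster-disjoint,
# i.e. no cluster contributes sequences to both train and val.
train_set = set(train_indices)
for cid, members in cluster_to_indices.items():
    hits = [i in train_set for i in members]
    assert all(hits) or not any(hits), f"cluster {cid} straddles the split"
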
df_with_split = df.copy()
df_with_split["split"] = split
df_with_split.to_csv("nf_meta_with_split.csv", index=False)

df_train = df_with_split[df_with_split["split"] == "train"].reset_index(drop=True)
df_val = df_with_split[df_with_split["split"] == "val"].reset_index(drop=True)

df_train.to_csv("nf_train.csv", index=False)
df_val.to_csv("nf_val.csv", index=False)

print("Split counts:")
print(df_with_split["split"].value_counts())
print()
print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")
print("Wrote:")
print(" - nf_meta_with_split.csv")
print(" - nf_train.csv")
print(" - nf_val.csv")

# Mean-pooled ESM-2 embeddings for the classifier.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = model.to(device)
model.eval()

def compute_embeddings(sequences, batch_size=32):
    """Return a numpy array of mean-pooled embeddings with shape (N, hidden_dim)."""
    embeddings = []
    for i in tqdm.trange(0, len(sequences), batch_size):
        batch_sequences = sequences[i:i + batch_size]

        inputs = tokenizer(
            batch_sequences,
            padding=True,
            max_length=1022,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        # Mean-pool over non-padding positions (note: the CLS/EOS special
        # tokens are included in the average).
        attention_mask = inputs["attention_mask"].unsqueeze(-1)
        masked_hidden_states = last_hidden_states * attention_mask
        sum_hidden_states = masked_hidden_states.sum(dim=1)
        seq_lengths = attention_mask.sum(dim=1)
        batch_embeddings = sum_hidden_states / seq_lengths

        embeddings.append(batch_embeddings.cpu())

    return torch.cat(embeddings, dim=0).numpy()

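# Quick smoke test on two hypothetical toy sequences: confirms the pooled
# embedding dimensionality (1280 for esm2_t33_650M_UR50D) before committing
# to the full dataset.
_demo = compute_embeddings(["MKTAYIAKQR", "GAVLIP"], batch_size=2)
print("Demo embedding shape:", _demo.shape)  # expect (2, 1280)
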
def create_and_save_datasets():
    meta = pd.read_csv(meta_path)
    sequences = meta["sequence"].tolist()
    labels = meta["label"].tolist()
    splits = meta["split"].tolist()

    print(f"Total sequences: {len(sequences)}")
    print("Split counts:", pd.Series(splits).value_counts().to_dict())

    print("Computing ESM embeddings...")
    embeddings = compute_embeddings(sequences)

    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "embedding": embeddings,
        "label": labels,
        "split": splits,
    })

    # Split on the precomputed assignment, then drop the helper column.
    train_ds = full_ds.filter(lambda x: x["split"] == "train")
    val_ds = full_ds.filter(lambda x: x["split"] == "val")

    train_ds = train_ds.remove_columns("split")
    val_ds = val_ds.remove_columns("split")

    ds_dict = DatasetDict({
        "train": train_ds,
        "val": val_ds,
    })

    ds_dict.save_to_disk(save_path)
    print(f"Saved DatasetDict with train/val to: {save_path}")
    print("Train size:", len(ds_dict["train"]))
    print("Val size:", len(ds_dict["val"]))

    return ds_dict


ds = create_and_save_datasets()

ex = ds["train"][0]
print("\nExample from train:")
print("Sequence:", ex["sequence"])
print("Embedding shape:", np.array(ex["embedding"]).shape)
print("Label:", ex["label"])

# Free GPU memory before the per-residue embedding pass below.
torch.cuda.empty_cache()

# Second pass: per-residue (unpooled) embeddings, stored as float16.
meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings_unpooled"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D", add_pooling_layer=False).to(device).eval()

cls_id = tokenizer.cls_token_id
eos_id = tokenizer.eos_token_id

@torch.no_grad()
def embed_one(seq, max_length=1022):
    """Return per-residue embeddings of shape (seq_len, hidden) with special tokens stripped."""
    inputs = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    out = model(**inputs)
    h = out.last_hidden_state[0]
    attn = inputs["attention_mask"][0].bool()
    ids = inputs["input_ids"][0]

    # Keep only real residue positions: drop padding and the CLS/EOS tokens.
    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    hb = h[keep].detach().cpu().to(torch.float16).numpy()
    return hb

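# Shape check on a hypothetical short peptide: for esm2_t33_650M_UR50D the
# result should be (len(seq), 1280) once CLS/EOS are removed.
_demo_res = embed_one("MKTAYIAKQR")
print("Per-residue embedding shape:", _demo_res.shape)  # expect (10, 1280)
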
# Arrow schema for the unpooled dataset: ragged (length, H) float16
# embeddings plus a per-residue attention mask.
H = 1280  # hidden size of esm2_t33_650M_UR50D
features = Features({
    "sequence": Value("string"),
    "label": Value("int64"),
    "embedding": Sequence(Sequence(Value("float16"), length=H)),
    "attention_mask": Sequence(Value("int8")),
    "length": Value("int64"),
})

def make_generator(df):
    # Yield one example per sequence, computing the embedding on the fly.
    for seq, lab in tqdm.tqdm(zip(df["sequence"].tolist(), df["label"].astype(int).tolist()), total=len(df)):
        emb = embed_one(seq)
        emb_list = emb.tolist()
        li = len(emb_list)
        yield {
            "sequence": seq,
            "label": int(lab),
            "embedding": emb_list,
            "attention_mask": [1] * li,
            "length": li,
        }

def build_split(df):
    # Stream examples through Dataset.from_generator so the embeddings are
    # written to Arrow incrementally instead of held in memory all at once.
    return Dataset.from_generator(make_generator, gen_kwargs={"df": df}, features=features)

meta = pd.read_csv(meta_path)
train_df = meta[meta["split"] == "train"].reset_index(drop=True)
val_df = meta[meta["split"] == "val"].reset_index(drop=True)

os.makedirs(save_path, exist_ok=True)

train_ds = build_split(train_df)
val_ds = build_split(val_df)

# Save once as a DatasetDict: this writes save_path/train, save_path/val,
# and the dataset_dict.json needed to reload both splits together.
ds_dict = DatasetDict({"train": train_ds, "val": val_ds})
ds_dict.save_to_disk(save_path, max_shard_size="1GB")
print(ds_dict)
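
# Reload sketch: consumers can load the saved DatasetDict back directly.
from datasets import load_from_disk

reloaded = load_from_disk(save_path)
print(reloaded["train"][0]["length"])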
|