import json
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_from_disk, DatasetDict
from lightning.pytorch import seed_everything
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve
from torch.utils.data import DataLoader

seed_everything(1986)


def infer_in_dim_from_unpooled_ds(ds) -> int:
    """Infer the embedding dimension H from the first example's first token."""
    ex = ds[0]
    return int(len(ex["embedding"][0]))


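# Expected record schema (inferred from the fields accessed in this script):
# each example carries "embedding" (an L x H list of floats), "length" (int),
# "label" (0/1), and optionally "attention_mask" and "sequence".

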
def load_split(dataset_path):
    """Load a saved DatasetDict and return its train and val splits."""
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")


def collate_unpooled(batch):
    """Pad variable-length per-token embeddings into a dense batch with a mask."""
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True

    return X, M, y


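# For example (hypothetical records): collating two examples with lengths 3 and
# 5 and H = 4 yields X of shape (2, 5, 4), M of shape (2, 5) with row 0 True
# only at positions 0-2, and y of shape (2,).

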
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    """Write per-example labels, probabilities, and thresholded predictions to CSV."""
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    """Save precision-recall and ROC curves for the given scores."""
    os.makedirs(out_dir, exist_ok=True)

    # Precision-recall curve.
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()

    # ROC curve.
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()


def best_f1_threshold(y_true, y_prob):
    """Return (threshold, F1) maximizing F1 along the precision-recall curve."""
    p, r, thr = precision_recall_curve(y_true, y_prob)
    # precision_recall_curve returns one more (precision, recall) point than
    # thresholds; drop the final (P=1, R=0) point so F1 aligns with `thr`.
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])


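# For example, best_f1_threshold([0, 1, 1], [0.2, 0.4, 0.9]) returns
# (0.4, 1.0): predicting positive at prob >= 0.4 classifies all three correctly.

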
@torch.no_grad()
def eval_probs(model, loader, device):
    """Collect ground-truth labels and sigmoid probabilities over a loader."""
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        logits = model(X, M)
        prob = torch.sigmoid(logits).cpu().numpy()
        ys.append(y.numpy())
        ps.append(prob)
    return np.concatenate(ys), np.concatenate(ps)


def train_one_epoch(model, loader, optim, criterion, device):
    """Run one optimization pass over the loader."""
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        logits = model(X, M)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()


class MaskedMeanPool(nn.Module):
    """Mean over the sequence dimension, counting only unmasked positions."""

    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)  # guard against all-padding rows
        return (X * Mf).sum(dim=1) / denom


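# For example, pooling X = [[1, 3], [5, 7], [9, 9]] with M = [True, True, False]
# averages only the first two positions and yields [3, 5].

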
class MLPClassifier(nn.Module):
    """Masked mean pooling followed by a two-layer MLP head."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)


class CNNClassifier(nn.Module):
    """Stack of 1D convolutions over the sequence, then masked mean pooling."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d expects (B, H, L); padded positions are already zero from the
        # collate, so the convolution reads zeros beyond each sequence's length.
        Xc = X.transpose(1, 2)
        Y = self.conv(Xc).transpose(1, 2)

        # Masked mean pool over valid positions only.
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Y * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)


class TransformerClassifier(nn.Module):
    """Linear projection, TransformerEncoder, masked mean pooling, linear head."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask marks positions to *ignore*, i.e. the inverse of M.
        pad_mask = ~M
        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)

        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Z * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)


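# A minimal smoke test (a hypothetical helper, not wired into the pipeline):
# each head should map a padded batch (B, L, H) plus mask (B, L) to one logit
# per example. Sizes here are arbitrary small values.
def _smoke_test(in_dim: int = 8) -> None:
    X = torch.randn(2, 5, in_dim)
    M = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]], dtype=torch.bool)
    heads = [
        MLPClassifier(in_dim=in_dim, hidden=16),
        CNNClassifier(in_ch=in_dim, c=16, k=3, layers=1),
        TransformerClassifier(in_dim=in_dim, d_model=16, nhead=4, ff=32),
    ]
    for head in heads:
        assert head(X, M).shape == (2,)  # one logit per sequence

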
def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
    """Optuna objective: train one configuration, return its best validation F1."""
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    if model_name == "mlp":
        hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
        model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
    elif model_name == "cnn":
        c = trial.suggest_categorical("channels", [128, 256, 512])
        k = trial.suggest_categorical("kernel", [3, 5, 7])
        layers = trial.suggest_int("layers", 1, 4)
        model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
    elif model_name == "transformer":
        d = trial.suggest_categorical("d_model", [128, 256, 384])
        nhead = trial.suggest_categorical("nhead", [4, 8])
        layers = trial.suggest_int("layers", 1, 4)
        ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
        model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead,
                                      layers=layers, ff=ff, dropout=dropout)
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    model = model.to(device)

    # Weight the positive class by the negative/positive ratio to counter imbalance.
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_f1 = -1.0
    patience = 8
    bad = 0

    for epoch in range(1, 51):
        train_one_epoch(model, train_loader, optim, criterion, device)

        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            bad = 0
            # Record the metrics of the best epoch so far on the trial, so the
            # reported val_auc/val_thr match the F1 returned as the objective.
            trial.set_user_attr("val_auc", float(auc))
            trial.set_user_attr("val_f1", float(f1))
            trial.set_user_attr("val_thr", float(thr))
        else:
            bad += 1

        trial.report(f1, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if bad >= patience:
            break

    return best_f1


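# Minimal standalone use of the objective (hypothetical paths), bypassing the
# full search-and-refit pipeline below:
#
#   train_ds, val_ds = load_split("data/embeddings_unpooled")
#   study = optuna.create_study(direction="maximize")
#   study.optimize(lambda t: objective_nn(t, "mlp", train_ds, val_ds), n_trials=5)

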
def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str,
                            n_trials: int = 50, device="cuda:0"):
    """Run an Optuna search for `model_name`, refit the best configuration with
    early stopping, and save the model, predictions, curves, and a summary."""
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device),
                   n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_f1_optuna = float(best.value)
    best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
    best_thr = float(best.user_attrs.get("val_thr", 0.5))

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    # Rebuild the loaders with the best batch size.
    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    # Rebuild the model from the best hyperparameters.
    dropout = float(best_params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPClassifier(
            in_dim=in_dim,
            hidden=int(best_params["hidden"]),
            dropout=dropout,
        )
    elif model_name == "cnn":
        model = CNNClassifier(
            in_ch=in_dim,
            c=int(best_params["channels"]),
            k=int(best_params["kernel"]),
            layers=int(best_params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        model = TransformerClassifier(
            in_dim=in_dim,
            d_model=int(best_params["d_model"]),
            nhead=int(best_params["nhead"]),
            layers=int(best_params["layers"]),
            ff=int(best_params["ff"]),
            dropout=dropout,
        )
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    model = model.to(device)

    # Same class weighting as in the search.
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # Refit with early stopping on validation F1, keeping the best state dict.
    best_f1_seen, bad, patience = -1.0, 0, 12
    best_state = None
    best_thr_seen = 0.5
    best_auc_seen = -1.0

    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)

        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        if f1 > best_f1_seen + 1e-4:
            best_f1_seen = f1
            best_thr_seen = thr
            best_auc_seen = auc
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final validation pass with the restored best weights.
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)

    model_path = os.path.join(out_dir, "best_model.pt")
    torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)

    y_true_tr, y_prob_tr = eval_probs(
        model,
        DataLoader(train_ds, batch_size=64, shuffle=False,
                   collate_fn=collate_unpooled, num_workers=4, pin_memory=True),
        device,
    )

    save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
                         sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
    save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
                         sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)

    plot_curves(out_dir, y_true_val, y_prob_val)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
        f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
        f"Best Optuna threshold (val): {best_thr:.4f}",
        f"Refit best AUC (val): {best_auc_seen:.4f}",
        f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
        "Best params:",
        json.dumps(best_params, indent=2),
        f"Saved model: {model_path}",
        "=" * 72,
    ]

    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


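# Reloading a saved checkpoint later (a sketch with placeholder values; assumes
# the same model class and the hyperparameters recorded in the checkpoint):
#
#   ckpt = torch.load("runs/mlp/best_model.pt", map_location="cpu")
#   model = MLPClassifier(in_dim=1024, hidden=int(ckpt["best_params"]["hidden"]),
#                         dropout=float(ckpt["best_params"]["dropout"]))
#   model.load_state_dict(ckpt["state_dict"])

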
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--n_trials", type=int, default=50)
    args = parser.parse_args()

    # argparse's `choices` already restricts --model to the supported heads.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device=device)
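
# Example invocation (script and data paths are placeholders):
#   python train_heads.py --dataset_path data/embeddings_unpooled \
#       --out_dir runs/transformer --model transformer --n_trials 50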