import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve
import optuna
import os
from typing import Optional
import matplotlib
matplotlib.use("Agg")  # headless-safe: this script only saves figures to disk
import matplotlib.pyplot as plt
import json
import pandas as pd
from lightning.pytorch import seed_everything
seed_everything(1986)
def infer_in_dim_from_unpooled_ds(ds) -> int:
ex = ds[0]
# ex["embedding"] is (L, H) list/array
return int(len(ex["embedding"][0]))
def load_split(dataset_path):
ds = load_from_disk(dataset_path)
if isinstance(ds, DatasetDict):
return ds["train"], ds["val"]
    raise ValueError(f"Expected a DatasetDict with 'train' and 'val' splits, got {type(ds).__name__}")
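# Assumed row schema for the unpooled dataset (a sketch inferred from the code
# below, not a guaranteed contract): each example carries "embedding" (an
# L x H float matrix, H = 1280 here), an integer "length", a binary "label",
# and optionally "attention_mask" and "sequence".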
def collate_unpooled(batch):
# batch: list of dicts
lengths = [int(x["length"]) for x in batch]
Lmax = max(lengths)
H = len(batch[0]["embedding"][0]) # 1280
X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
for i, x in enumerate(batch):
emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L, H)
L = emb.shape[0]
X[i, :L] = emb
if "attention_mask" in x:
m = torch.tensor(x["attention_mask"], dtype=torch.bool)
M[i, :L] = m[:L]
else:
M[i, :L] = True
return X, M, y
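# Minimal sketch of what collate_unpooled produces, on synthetic examples that
# only mimic the schema assumed above (H is kept tiny here instead of the real
# 1280-dim embeddings); not called anywhere, kept as a reference check:
def _demo_collate_unpooled():
    fake_batch = [
        {"embedding": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], "length": 3, "label": 1},
        {"embedding": [[0.7, 0.8]], "length": 1, "label": 0},
    ]
    X, M, y = collate_unpooled(fake_batch)
    assert X.shape == (2, 3, 2)  # (batch, Lmax, H); short row is zero-padded
    assert M.tolist() == [[True, True, True], [True, False, False]]
    assert y.tolist() == [1.0, 0.0]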
# ======================== Helper functions =========================================
def save_predictions_csv(
out_dir: str,
split_name: str,
y_true: np.ndarray,
y_prob: np.ndarray,
threshold: float,
sequences: Optional[np.ndarray] = None,
):
os.makedirs(out_dir, exist_ok=True)
df = pd.DataFrame({
"y_true": y_true.astype(int),
"y_prob": y_prob.astype(float),
"y_pred": (y_prob >= threshold).astype(int),
})
if sequences is not None:
df.insert(0, "sequence", sequences)
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
os.makedirs(out_dir, exist_ok=True)
# PR
precision, recall, _ = precision_recall_curve(y_true, y_prob)
plt.figure()
plt.plot(recall, precision)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "pr_curve.png"))
plt.close()
# ROC
fpr, tpr, _ = roc_curve(y_true, y_prob)
plt.figure()
plt.plot(fpr, tpr)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "roc_curve.png"))
plt.close()
# ======================== Shared OPTUNA training scheme =========================================
def best_f1_threshold(y_true, y_prob):
    # precision_recall_curve returns len(thr) + 1 (precision, recall) points;
    # drop the final sentinel point (recall = 0) so f1s aligns 1:1 with thr.
    p, r, thr = precision_recall_curve(y_true, y_prob)
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])
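# Worked example (toy values borrowed from the sklearn docs, not from this
# project's data): the loosest cut wins here, thr = 0.35 with precision = 2/3
# and recall = 1, giving F1 = 0.8; stricter cuts only lower F1.
def _demo_best_f1_threshold():
    thr, f1 = best_f1_threshold(np.array([0, 0, 1, 1]),
                                np.array([0.1, 0.4, 0.35, 0.8]))
    assert abs(thr - 0.35) < 1e-9 and abs(f1 - 0.8) < 1e-9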
@torch.no_grad()
def eval_probs(model, loader, device):
model.eval()
ys, ps = [], []
for X, M, y in loader:
X, M = X.to(device), M.to(device)
logits = model(X, M)
prob = torch.sigmoid(logits).detach().cpu().numpy()
ys.append(y.numpy())
ps.append(prob)
return np.concatenate(ys), np.concatenate(ps)
def train_one_epoch(model, loader, optim, criterion, device):
model.train()
for X, M, y in loader:
X, M, y = X.to(device), M.to(device), y.to(device)
optim.zero_grad(set_to_none=True)
logits = model(X, M)
loss = criterion(logits, y)
loss.backward()
optim.step()
# ======================== MLP =========================================
# Still need mean pooling along lengths
class MaskedMeanPool(nn.Module):
def forward(self, X, M): # X: (B,L,H), M: (B,L)
Mf = M.unsqueeze(-1).float()
denom = Mf.sum(dim=1).clamp(min=1.0)
return (X * Mf).sum(dim=1) / denom # (B,H)
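# Tiny numeric check (a sketch with made-up values): positions beyond each
# sequence's true length must not leak into the mean.
def _demo_masked_mean_pool():
    X = torch.tensor([[[2.0], [4.0], [99.0]]])  # (B=1, L=3, H=1); last row is padding
    M = torch.tensor([[True, True, False]])
    assert MaskedMeanPool()(X, M).item() == 3.0  # mean of 2 and 4; 99 is ignored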
class MLPClassifier(nn.Module):
def __init__(self, in_dim, hidden=512, dropout=0.1):
super().__init__()
self.pool = MaskedMeanPool()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden, 1),
)
def forward(self, X, M):
z = self.pool(X, M)
return self.net(z).squeeze(-1) # logits
# ======================== CNN =========================================
# Treat 1280 dimensions as channels
class CNNClassifier(nn.Module):
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
super().__init__()
blocks = []
ch = in_ch
for _ in range(layers):
blocks += [
nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
nn.GELU(),
nn.Dropout(dropout),
]
ch = c
self.conv = nn.Sequential(*blocks)
self.head = nn.Linear(c, 1)
def forward(self, X, M):
# X: (B,L,H) -> (B,H,L)
Xc = X.transpose(1, 2)
Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
# masked mean pool over L
Mf = M.unsqueeze(-1).float()
denom = Mf.sum(dim=1).clamp(min=1.0)
pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
return self.head(pooled).squeeze(-1)
# ========================== Transformer ====================================
# Note: no positional encoding is added before the encoder, so attention sees
# the residues as a set; the per-residue embeddings are assumed to already
# carry positional context from the upstream language model.
class TransformerClassifier(nn.Module):
def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
super().__init__()
self.proj = nn.Linear(in_dim, d_model)
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=ff,
dropout=dropout, batch_first=True, activation="gelu"
)
self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
self.head = nn.Linear(d_model, 1)
def forward(self, X, M):
# src_key_padding_mask: True = pad positions
pad_mask = ~M
Z = self.proj(X) # (B,L,d)
Z = self.enc(Z, src_key_padding_mask=pad_mask) # (B,L,d)
Mf = M.unsqueeze(-1).float()
denom = Mf.sum(dim=1).clamp(min=1.0)
pooled = (Z * Mf).sum(dim=1) / denom
return self.head(pooled).squeeze(-1)
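# Quick shape smoke test (a hedged sketch on random tensors, not part of the
# training pipeline): all three heads should map (B, L, H) inputs plus (B, L)
# masks to (B,) logits. Not called anywhere; run manually if useful.
def _demo_model_shapes(in_dim: int = 32, B: int = 4, L: int = 10):
    X = torch.randn(B, L, in_dim)
    M = torch.ones(B, L, dtype=torch.bool)
    for m in (MLPClassifier(in_dim), CNNClassifier(in_dim),
              TransformerClassifier(in_dim, d_model=16, nhead=4, ff=32)):
        assert m(X, M).shape == (B,)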
# ========================== OPTUNA ====================================
def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
# hyperparams shared
lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
dropout = trial.suggest_float("dropout", 0.0, 0.5)
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
in_dim = infer_in_dim_from_unpooled_ds(train_ds)
if model_name == "mlp":
hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
elif model_name == "cnn":
c = trial.suggest_categorical("channels", [128, 256, 512])
k = trial.suggest_categorical("kernel", [3, 5, 7])
layers = trial.suggest_int("layers", 1, 4)
model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
elif model_name == "transformer":
d = trial.suggest_categorical("d_model", [128, 256, 384])
nhead = trial.suggest_categorical("nhead", [4, 8])
layers = trial.suggest_int("layers", 1, 4)
ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
else:
raise ValueError(model_name)
model = model.to(device)
    # class imbalance handling: weight positives by neg/pos in the loss
    # (e.g. 900 negatives vs 100 positives -> pos_weight = 9.0)
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    best_f1 = -1.0
    patience = 8
    bad = 0
    for epoch in range(1, 51):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)
        # report to the pruner before the early-stopping bookkeeping
        trial.report(f1, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            bad = 0
            # record user attrs at the best epoch (not the last one), so the
            # reported AUC/threshold match the returned objective value
            trial.set_user_attr("val_auc", float(auc))
            trial.set_user_attr("val_f1", float(f1))
            trial.set_user_attr("val_thr", float(thr))
        else:
            bad += 1
            if bad >= patience:
                break
    return best_f1
def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str, n_trials: int = 50, device="cuda:0"):
os.makedirs(out_dir, exist_ok=True)
train_ds, val_ds = load_split(dataset_path)
print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device), n_trials=n_trials)
trials_df = study.trials_dataframe()
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
best = study.best_trial
best_params = dict(best.params)
best_f1_optuna = float(best.value)
best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
best_thr = float(best.user_attrs.get("val_thr", 0.5))
in_dim = infer_in_dim_from_unpooled_ds(train_ds)
# --- Refit best model ---
batch_size = int(best_params.get("batch_size", 32))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
# Rebuild
dropout = float(best_params.get("dropout", 0.1))
if model_name == "mlp":
model = MLPClassifier(
in_dim=in_dim,
hidden=int(best_params["hidden"]),
dropout=dropout,
)
elif model_name == "cnn":
model = CNNClassifier(
in_ch=in_dim,
c=int(best_params["channels"]),
k=int(best_params["kernel"]),
layers=int(best_params["layers"]),
dropout=dropout,
)
elif model_name == "transformer":
model = TransformerClassifier(
in_dim=in_dim,
d_model=int(best_params["d_model"]),
nhead=int(best_params["nhead"]),
layers=int(best_params["layers"]),
ff=int(best_params["ff"]),
dropout=dropout,
)
else:
raise ValueError(model_name)
model = model.to(device)
    # loss + optimizer (same pos_weight class weighting as in objective_nn)
ytr = np.asarray(train_ds["label"], dtype=np.int64)
pos = ytr.sum()
neg = len(ytr) - pos
pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
lr = float(best_params["lr"])
wd = float(best_params["weight_decay"])
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # refit with a longer budget, early-stopping on val F1 (AUC is tracked for reporting)
    best_f1_seen, bad, patience = -1.0, 0, 12
best_state = None
best_thr_seen = 0.5
best_auc_seen = -1.0
for epoch in range(1, 151):
train_one_epoch(model, train_loader, optim, criterion, device)
y_true, y_prob = eval_probs(model, val_loader, device)
auc = roc_auc_score(y_true, y_prob)
thr, f1 = best_f1_threshold(y_true, y_prob)
if f1 > best_f1_seen + 1e-4:
best_f1_seen = f1
best_thr_seen = thr
best_auc_seen = auc
bad = 0
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
else:
bad += 1
if bad >= patience:
break
if best_state is not None:
model.load_state_dict(best_state)
# final preds + threshold picked on val
y_true_val, y_prob_val = eval_probs(model, val_loader, device)
best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
# save model
model_path = os.path.join(out_dir, "best_model.pt")
torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)
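    # Sketch of reloading this checkpoint later (hypothetical caller code; the
    # matching class must be rebuilt from best_params, shown for the MLP case):
    #   ckpt = torch.load(model_path, map_location="cpu")
    #   clf = MLPClassifier(in_dim, hidden=int(ckpt["best_params"]["hidden"]))
    #   clf.load_state_dict(ckpt["state_dict"])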
# train preds
y_true_tr, y_prob_tr = eval_probs(model, DataLoader(train_ds, batch_size=64, shuffle=False,
collate_fn=collate_unpooled, num_workers=4, pin_memory=True), device)
save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)
plot_curves(out_dir, y_true_val, y_prob_val)
summary = [
"=" * 72,
f"MODEL: {model_name}",
        # Optuna results (objective = F1; AUC/threshold recorded at the best epoch)
        f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
        f"Best Optuna AUC (val, best epoch): {best_auc_optuna:.4f}",
        f"Best Optuna threshold (val): {best_thr:.4f}",
# Refit results
f"Refit best AUC (val): {best_auc_seen:.4f}",
f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
"Best params:",
json.dumps(best_params, indent=2),
f"Saved model: {model_path}",
"=" * 72,
]
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
f.write("\n".join(summary))
print("\n".join(summary))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--out_dir", type=str, required=True)
parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
parser.add_argument("--n_trials", type=int, default=50)
    args = parser.parse_args()
    # argparse's `choices` already validates --model; fall back to CPU when no GPU is visible
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device=device)
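# Example invocation (script name and paths are placeholders):
#   python train_heads.py --dataset_path data/esm_unpooled --out_dir runs/mlp \
#       --model mlp --n_trials 50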