"""
threshold_optimizer.py — Post-training threshold calibration tool.

Run this standalone to re-optimize the probability threshold on new data
WITHOUT retraining the model. Useful for:
  - Adapting to regime changes without full retraining
  - Testing different optimization objectives
  - Out-of-sample threshold validation

The threshold search maximizes expectancy or Sharpe over a held-out dataset.

Usage:
    python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200
    python threshold_optimizer.py --objective sharpe
"""

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

sys.path.insert(0, str(Path(__file__).parent))

from ml_config import (
    THRESHOLD_PATH, THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS,
    THRESHOLD_OBJECTIVE, TARGET_RR, ROUND_TRIP_COST, FEATURE_COLUMNS, ML_DIR,
)
from ml_filter import TradeFilter
from feature_builder import build_feature_dict, validate_features
from train import build_dataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# Below this many selected samples the per-threshold metrics are reported as NaN.
_MIN_SAMPLES_PER_THRESHOLD = 5
# Threshold returned when no grid point satisfies the trade-count constraint.
_DEFAULT_THRESHOLD = 0.55


def _threshold_metrics(
    probs: np.ndarray,
    y_true: np.ndarray,
    threshold: float,
    rr: float,
    cost: float,
) -> dict:
    """Compute the metric record for a single probability threshold.

    Samples with ``probs >= threshold`` are treated as taken trades. A trade
    whose label equals 1 earns ``rr``; any other label loses 1.0; ``cost`` is
    subtracted from every trade.

    Returns:
        Dict with keys ``threshold``, ``n_trades``, ``win_rate``,
        ``expectancy``, ``sharpe``, ``precision`` and ``coverage``. When fewer
        than ``_MIN_SAMPLES_PER_THRESHOLD`` samples pass the threshold, the
        metric fields are NaN and ``coverage`` is 0.0.
    """
    selected = probs >= threshold
    n_trades = int(selected.sum())
    # Rounded in both branches so every row of the curve uses the same key format.
    rounded_threshold = round(float(threshold), 4)

    if n_trades < _MIN_SAMPLES_PER_THRESHOLD:
        return {
            "threshold": rounded_threshold,
            "n_trades": n_trades,
            "win_rate": np.nan,
            "expectancy": np.nan,
            "sharpe": np.nan,
            "precision": np.nan,
            "coverage": 0.0,
        }

    outcomes = y_true[selected]
    win_rate = float(outcomes.mean())
    expectancy = win_rate * rr - (1 - win_rate) * 1.0 - cost
    pnl = np.where(outcomes == 1, rr, -1.0) - cost
    pnl_std = pnl.std()
    # A (near-)zero spread has no meaningful ratio; report 0.0 instead of inf/NaN.
    sharpe = (pnl.mean() / pnl_std * np.sqrt(252)) if pnl_std > 1e-9 else 0.0
    coverage = n_trades / len(y_true)

    return {
        "threshold": rounded_threshold,
        "n_trades": n_trades,
        "win_rate": round(win_rate, 4),
        "expectancy": round(expectancy, 4),
        "sharpe": round(sharpe, 4),
        # Precision of the "take trade" decision equals the win rate here.
        "precision": round(win_rate, 4),
        "coverage": round(coverage, 4),
    }


def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
) -> pd.DataFrame:
    """
    Sweep threshold grid and compute metrics at each threshold.

    The grid is ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.

    Args:
        probs: Predicted probabilities, one per sample.
        y_true: Labels aligned with ``probs``; a label equal to 1 is a win.
        rr: Reward earned by a winning trade (a losing trade loses 1.0).
        cost: Cost subtracted from every trade.

    Returns DataFrame for analysis and plotting, one row per grid threshold.
    """
    # Accept list-like inputs as well as arrays.
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)

    grid = np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)
    return pd.DataFrame(
        [_threshold_metrics(probs, y_true, t, rr, cost) for t in grid]
    )


def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
) -> float:
    """Return the threshold whose ``objective`` metric is highest.

    Only rows with at least ``min_trades`` trades and a non-NaN objective
    value are considered.

    Args:
        curve: Output of ``compute_threshold_curve``.
        objective: Name of the curve column to maximize.
        min_trades: Minimum ``n_trades`` a row needs to be eligible.

    Returns:
        The best threshold, or 0.55 (with a warning) when no row is eligible.
    """
    if curve.empty:
        # An empty curve has no columns, so the filter below would raise KeyError.
        logger.warning("No valid threshold found — using default 0.55")
        return _DEFAULT_THRESHOLD

    eligible = curve[curve["n_trades"] >= min_trades].dropna(subset=[objective])
    if eligible.empty:
        logger.warning("No valid threshold found — using default 0.55")
        return _DEFAULT_THRESHOLD

    best_row = eligible.loc[eligible[objective].idxmax()]
    return float(best_row["threshold"])


def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path):
    """Save a 2x2 figure of expectancy, Sharpe, win rate and trade count.

    Each panel plots one metric against the threshold (rows where that metric
    is NaN are skipped) and marks ``optimal`` with a dashed vertical line.

    Args:
        curve: Output of ``compute_threshold_curve``.
        optimal: Threshold to highlight.
        save_path: Destination image path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")

        metrics = ["expectancy", "sharpe", "win_rate", "n_trades"]
        titles = ["Expectancy per Trade", "Annualized Sharpe", "Win Rate", "# Trades"]

        for ax, metric, title in zip(axes.flatten(), metrics, titles):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(optimal, color="orange", linestyle="--", lw=1.5,
                       label=f"Optimal={optimal:.3f}")
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Release this specific figure even when rendering or saving fails.
        plt.close(fig)
    logger.info(f"Threshold curve plot saved → {save_path}")


def main(args):
    """Re-optimize the probability threshold and persist the results.

    Loads the trained filter, builds a dataset for the requested symbols,
    scores it, sweeps the threshold grid, then writes the chosen threshold to
    ``THRESHOLD_PATH`` and the curve CSV and plot into ``ML_DIR``.

    Args:
        args: Parsed CLI namespace with ``symbols``, ``bars`` and ``objective``.

    Exits with status 1 when no trained model is available or the dataset is
    empty.
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    if len(dataset) == 0:
        # Nothing to calibrate on; keep the previously saved threshold untouched.
        logger.error("Dataset is empty — cannot optimize threshold.")
        sys.exit(1)

    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)

    # One {feature: float} dict per sample, built from the already-converted
    # matrix rather than a per-row DataFrame iteration.
    feature_dicts = [dict(zip(FEATURE_COLUMNS, row)) for row in X.tolist()]
    probs = np.asarray(trade_filter.predict_batch(feature_dicts))
    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)
    optimal = find_optimal_threshold(curve, objective=args.objective)

    on_grid = curve[np.isclose(curve["threshold"], optimal, rtol=0.0, atol=5e-5)]
    if on_grid.empty:
        # The fallback default need not lie on the sweep grid; evaluate the
        # metrics directly at the chosen threshold instead of failing.
        best_row = _threshold_metrics(probs, y, optimal, TARGET_RR, ROUND_TRIP_COST)
    else:
        best_row = on_grid.iloc[0].to_dict()

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Update threshold file
    ML_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE: NaN metrics (too few selected samples) are written by json as the
    # non-standard token NaN.
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": float(best_row["win_rate"]),
        "expectancy_at_threshold": float(best_row["expectancy"]),
        "sharpe_at_threshold": float(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold updated → {THRESHOLD_PATH}")

    # Save curve CSV
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plot (best effort: a plotting failure must not fail the calibration run)
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize probability threshold")
    parser.add_argument("--symbols", nargs="+", default=None)
    parser.add_argument("--bars", type=int, default=200)
    parser.add_argument("--objective", choices=["expectancy", "sharpe", "win_rate"],
                        default=THRESHOLD_OBJECTIVE)
    args = parser.parse_args()
    main(args)