| | """ |
| | threshold_optimizer.py — Post-training threshold calibration tool. |
| | |
| | Run this standalone to re-optimize the probability threshold on new data |
| | WITHOUT retraining the model. Useful for: |
| | - Adapting to regime changes without full retraining |
| | - Testing different optimization objectives |
| | - Out-of-sample threshold validation |
| | |
| | The threshold search maximizes expectancy or Sharpe over a held-out dataset. |
| | |
| | Usage: |
| | python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200 |
| | python threshold_optimizer.py --objective sharpe |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import logging |
| | import sys |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib |
| | matplotlib.use("Agg") |
| | import matplotlib.pyplot as plt |
| |
|
| | sys.path.insert(0, str(Path(__file__).parent)) |
| |
|
| | from ml_config import ( |
| | THRESHOLD_PATH, |
| | THRESHOLD_MIN, |
| | THRESHOLD_MAX, |
| | THRESHOLD_STEPS, |
| | THRESHOLD_OBJECTIVE, |
| | TARGET_RR, |
| | ROUND_TRIP_COST, |
| | FEATURE_COLUMNS, |
| | ML_DIR, |
| | ) |
| | from ml_filter import TradeFilter |
| | from feature_builder import build_feature_dict, validate_features |
| | from train import build_dataset |
| |
|
| | logger = logging.getLogger(__name__) |
| | logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") |
| |
|
| |
|
def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
) -> pd.DataFrame:
    """
    Sweep the threshold grid and compute trade metrics at each threshold.

    The grid is ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.
    For every grid value ``t`` the samples with ``probs >= t`` are treated as
    taken trades: a sample with label 1 pays ``rr``, any other label pays
    ``-1.0``, and ``cost`` is subtracted from every trade.

    Args:
        probs: Predicted probabilities, one per sample.
        y_true: Labels aligned with ``probs`` (1 counts as a win).
        rr: Payoff of a winning trade, in units of the loss on a losing one.
        cost: Cost subtracted from every trade's payoff.

    Returns:
        DataFrame with one row per threshold and the columns ``threshold``,
        ``n_trades``, ``win_rate``, ``expectancy``, ``sharpe``, ``precision``
        and ``coverage``. Thresholds that select fewer than 5 trades carry
        NaN metrics, but still report their real ``n_trades`` and
        ``coverage``.

    Raises:
        ValueError: If ``probs`` and ``y_true`` do not have the same shape.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    if probs.shape != y_true.shape:
        raise ValueError(
            f"probs and y_true must have the same shape, "
            f"got {probs.shape} and {y_true.shape}"
        )

    total = len(y_true)
    thresholds = np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)
    records = []

    for t in thresholds:
        mask = probs >= t
        n = int(mask.sum())
        # Round the threshold on every row (including sparse ones) so that
        # equality lookups against the returned grid behave uniformly.
        threshold = round(t, 4)
        coverage = round(n / total, 4) if total else 0.0

        if n < 5:
            # Too few trades for meaningful statistics: keep the row so the
            # grid stays complete, but mark the metrics as missing.
            records.append({
                "threshold": threshold,
                "n_trades": n,
                "win_rate": np.nan,
                "expectancy": np.nan,
                "sharpe": np.nan,
                "precision": np.nan,
                "coverage": coverage,
            })
            continue

        y_f = y_true[mask]
        wr = float(y_f.mean())
        exp = wr * rr - (1 - wr) * 1.0 - cost
        pnl = np.where(y_f == 1, rr, -1.0) - cost
        pnl_std = pnl.std()
        # NOTE(review): the sqrt(252) scaling presumably treats each trade as
        # one daily observation -- confirm this matches the trade frequency.
        sh = (pnl.mean() / pnl_std * np.sqrt(252)) if pnl_std > 1e-9 else 0.0

        records.append({
            "threshold": threshold,
            "n_trades": n,
            "win_rate": round(wr, 4),
            "expectancy": round(exp, 4),
            "sharpe": round(sh, 4),
            # Among the selected samples, precision equals the win rate.
            "precision": round(wr, 4),
            "coverage": coverage,
        })

    return pd.DataFrame(records)
| |
|
| |
|
def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
    fallback: float = 0.55,
) -> float:
    """
    Pick the threshold that maximizes ``objective`` on a threshold curve.

    Only rows with at least ``min_trades`` trades and a non-NaN objective
    value are considered. Ties go to the first maximal row.

    Args:
        curve: Frame with at least the columns ``threshold``, ``n_trades``
            and the column named by ``objective``.
        objective: Name of the metric column to maximize.
        min_trades: Minimum ``n_trades`` a row needs to be eligible.
        fallback: Threshold returned when no row is eligible.

    Returns:
        The best threshold, or ``fallback`` when no row qualifies.
    """
    # A frame built from zero records has no columns at all; guard it so the
    # column lookups below cannot raise KeyError.
    if curve.empty:
        eligible = curve
    else:
        eligible = curve[curve["n_trades"] >= min_trades].dropna(subset=[objective])

    if eligible.empty:
        logger.warning("No valid threshold found - using default %s", fallback)
        return float(fallback)

    best_idx = eligible[objective].idxmax()
    return float(eligible.loc[best_idx, "threshold"])
| |
|
| |
|
def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path) -> None:
    """
    Render a 2x2 grid of metric-vs-threshold plots and save it to disk.

    The panels show ``expectancy``, ``sharpe``, ``win_rate`` and ``n_trades``
    against ``threshold``; rows where the plotted metric is NaN are skipped.
    A dashed vertical line marks ``optimal`` in each panel.

    Args:
        curve: Frame with a ``threshold`` column plus the four metric columns.
        optimal: Threshold to highlight.
        save_path: Destination file for the rendered figure.
    """
    panels = (
        ("expectancy", "Expectancy per Trade"),
        ("sharpe", "Annualized Sharpe"),
        ("win_rate", "Win Rate"),
        ("n_trades", "# Trades"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")

        for ax, (metric, title) in zip(axes.flatten(), panels):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(
                optimal,
                color="orange",
                linestyle="--",
                lw=1.5,
                label=f"Optimal={optimal:.3f}",
            )
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)

        # Operate on the figure handle rather than pyplot's implicit
        # "current figure" so unrelated open figures cannot interfere.
        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    logger.info("Threshold curve plot saved -> %s", save_path)
| |
|
| |
|
def main(args):
    """
    Re-optimize the probability threshold with the already trained model.

    Steps: load the trained filter, build a dataset for the requested
    symbols, score every row, sweep the threshold grid, pick the best
    threshold for ``args.objective``, then write the result to
    ``THRESHOLD_PATH`` and save the full curve as CSV and PNG under
    ``ML_DIR``.

    Args:
        args: Parsed CLI namespace providing ``symbols`` (list or None),
            ``bars`` and ``objective``.

    Exits with status 1 when no trained model is available or the dataset
    is empty.
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    if len(dataset) == 0:
        logger.error("Dataset is empty - nothing to optimize the threshold on.")
        sys.exit(1)

    y = dataset["label"].values.astype(np.int32)

    # One float-valued feature dict per row. The float64 cast also fails
    # early if a feature column is not numeric.
    feature_dicts = dataset[FEATURE_COLUMNS].astype(np.float64).to_dict("records")
    probs = np.asarray(trade_filter.predict_batch(feature_dicts))

    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)
    optimal = find_optimal_threshold(curve, objective=args.objective)

    matches = curve[curve["threshold"].round(4) == round(optimal, 4)]
    if matches.empty:
        # find_optimal_threshold falls back to a fixed default when no grid
        # point qualifies, and that default need not lie on the grid. Report
        # what can still be measured instead of crashing on .iloc[0].
        logger.warning(
            "Threshold %.4f is not on the search grid; metrics are unavailable",
            optimal,
        )
        n_selected = int((probs >= optimal).sum())
        best_row = pd.Series({
            "win_rate": np.nan,
            "expectancy": np.nan,
            "sharpe": np.nan,
            "n_trades": n_selected,
            "coverage": n_selected / len(probs),
        })
    else:
        best_row = matches.iloc[0]

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Persist the chosen threshold together with its metrics.
    # NOTE(review): NaN metrics are written as the non-standard JSON token
    # NaN (json.dump's default); confirm the reader of this file accepts it.
    ML_DIR.mkdir(parents=True, exist_ok=True)
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": float(best_row["win_rate"]),
        "expectancy_at_threshold": float(best_row["expectancy"]),
        "sharpe_at_threshold": float(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold updated -> {THRESHOLD_PATH}")

    # Save the full curve for offline analysis.
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plotting is best-effort: a rendering failure must not fail the run,
    # because the threshold itself has already been saved.
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: read the command-line options and hand the parsed
    # namespace to main().
    cli = argparse.ArgumentParser(description="Optimize probability threshold")
    cli.add_argument(
        "--symbols",
        default=None,
        nargs="+",
    )
    cli.add_argument(
        "--bars",
        default=200,
        type=int,
    )
    cli.add_argument(
        "--objective",
        default=THRESHOLD_OBJECTIVE,
        choices=["expectancy", "sharpe", "win_rate"],
    )
    main(cli.parse_args())
| |
|