# Goshawk_Hedge_Pro / threshold_optimizer.py
# GoshawkVortexAI's picture
# Create threshold_optimizer.py
# e365f22 verified
"""
threshold_optimizer.py β€” Post-training threshold calibration tool.
Run this standalone to re-optimize the probability threshold on new data
WITHOUT retraining the model. Useful for:
- Adapting to regime changes without full retraining
- Testing different optimization objectives
- Out-of-sample threshold validation
The threshold search maximizes expectancy or Sharpe over a held-out dataset.
Usage:
python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200
python threshold_optimizer.py --objective sharpe
"""
import argparse
import json
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg") # non-interactive backend
import matplotlib.pyplot as plt
sys.path.insert(0, str(Path(__file__).parent))
from ml_config import (
THRESHOLD_PATH,
THRESHOLD_MIN,
THRESHOLD_MAX,
THRESHOLD_STEPS,
THRESHOLD_OBJECTIVE,
TARGET_RR,
ROUND_TRIP_COST,
FEATURE_COLUMNS,
ML_DIR,
)
from ml_filter import TradeFilter
from feature_builder import build_feature_dict, validate_features
from train import build_dataset
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time, so merely importing this
# module (not only running it as a script) configures the root logger.
# Consider moving it under the `if __name__ == "__main__":` guard.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
    min_trades: int = 5,
    thresholds: "np.ndarray | None" = None,
) -> pd.DataFrame:
    """
    Sweep a threshold grid and compute trade metrics at each threshold.

    A sample counts as a taken trade at threshold ``t`` when ``probs >= t``.
    Each taken trade earns ``rr`` if its label is 1 and loses ``1.0``
    otherwise, and every taken trade pays ``cost``.

    Args:
        probs: Predicted probabilities, one per sample.
        y_true: Labels aligned with ``probs``. ``win_rate`` is their mean over
            taken trades, so labels are assumed to be 0/1.
        rr: Reward multiple credited on a winning trade.
        cost: Cost subtracted from every taken trade.
        min_trades: Thresholds that take fewer trades than this (and any
            threshold that takes none) get NaN metrics and ``coverage`` 0.0.
        thresholds: Grid to sweep. Defaults to
            ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.

    Returns:
        One row per threshold with columns ``threshold`` (rounded to 4 dp),
        ``n_trades``, ``win_rate``, ``expectancy``, ``sharpe``, ``precision``
        (same value as ``win_rate``) and ``coverage`` (fraction of samples
        traded). ``sharpe`` is mean/std of per-trade pnl scaled by sqrt(252),
        i.e. it treats each trade as one daily observation; it is 0.0 when the
        pnl spread is (near) zero.

    Raises:
        ValueError: if ``probs`` and ``y_true`` have different lengths.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    if len(probs) != len(y_true):
        raise ValueError(
            f"probs and y_true must have the same length ({len(probs)} != {len(y_true)})"
        )
    if thresholds is None:
        thresholds = np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)

    # Never compute metrics over zero trades (empty mean / division by zero),
    # even if a caller passes min_trades <= 0.
    trade_floor = max(int(min_trades), 1)
    n_samples = len(y_true)
    annualize = np.sqrt(252)

    records = []
    for t in thresholds:
        mask = probs >= t
        n = int(mask.sum())
        if n < trade_floor:
            # Too few trades for meaningful statistics. The threshold is rounded
            # like every other row so the column has consistent precision.
            records.append({
                "threshold": round(t, 4), "n_trades": n,
                "win_rate": np.nan, "expectancy": np.nan,
                "sharpe": np.nan, "precision": np.nan,
                "coverage": 0.0,
            })
            continue
        taken = y_true[mask]
        win_rate = float(taken.mean())
        expectancy = win_rate * rr - (1 - win_rate) * 1.0 - cost
        pnl = np.where(taken == 1, rr, -1.0) - cost
        spread = pnl.std()
        sharpe = pnl.mean() / spread * annualize if spread > 1e-9 else 0.0
        records.append({
            "threshold": round(t, 4),
            "n_trades": n,
            "win_rate": round(win_rate, 4),
            "expectancy": round(expectancy, 4),
            "sharpe": round(sharpe, 4),
            "precision": round(win_rate, 4),
            "coverage": round(n / n_samples, 4),
        })
    return pd.DataFrame(records)
def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
    default: float = 0.55,
) -> float:
    """
    Pick the threshold in ``curve`` that maximizes ``objective``.

    Only rows with at least ``min_trades`` trades and a non-NaN ``objective``
    value are eligible. On ties the earliest eligible row wins, because
    ``idxmax`` returns the first maximum.

    Args:
        curve: Output of ``compute_threshold_curve``.
        objective: Column of ``curve`` to maximize.
        min_trades: Minimum ``n_trades`` a row needs to be eligible.
        default: Threshold returned (with a warning) when no row is eligible.
            It is not guaranteed to be one of the values in ``curve``.

    Returns:
        The chosen threshold as a float.

    Raises:
        KeyError: if a non-empty ``curve`` lacks ``n_trades`` or ``objective``.
    """
    # An empty sweep has no columns at all, so guard before indexing into it.
    if curve.empty:
        valid = curve
    else:
        valid = curve[curve["n_trades"] >= min_trades].dropna(subset=[objective])
    if valid.empty:
        logger.warning(f"No valid threshold found — using default {default}")
        return float(default)
    best_row = valid.loc[valid[objective].idxmax()]
    return float(best_row["threshold"])
def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path):
    """
    Save a 2x2 grid of threshold-sweep plots to ``save_path``.

    Panels plot expectancy, Sharpe, win rate and trade count against the
    threshold. Each panel has a dashed vertical marker at ``optimal`` and a
    dotted horizontal line at zero. NaN points (thresholds with too few trades)
    are dropped per panel.

    Args:
        curve: Output of ``compute_threshold_curve``.
        optimal: Threshold to highlight.
        save_path: Destination image file. Its directory must already exist.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")
        panels = (
            ("expectancy", "Expectancy per Trade"),
            ("sharpe", "Annualized Sharpe"),
            ("win_rate", "Win Rate"),
            ("n_trades", "# Trades"),
        )
        for ax, (metric, title) in zip(axes.flatten(), panels):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(optimal, color="orange", linestyle="--", lw=1.5, label=f"Optimal={optimal:.3f}")
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Always release this figure. Callers may catch a plotting error and
        # carry on, and pyplot would otherwise keep the figure alive.
        plt.close(fig)
    logger.info(f"Threshold curve plot saved → {save_path}")
def _finite_or_none(value):
    """Return ``value`` as a float, or None when it is NaN or infinite.

    ``json.dump`` writes NaN/Infinity literals by default, but strict JSON has
    none. Mapping them to ``null`` keeps the threshold file parseable by any
    JSON reader.
    """
    value = float(value)
    return value if np.isfinite(value) else None


def main(args):
    """Re-optimize the saved model's probability threshold and persist it.

    Steps:
      1. Load the trained ``TradeFilter``; exit with status 1 if none exists.
      2. Build a labelled dataset for ``args.symbols`` (or a default basket)
         with ``args.bars`` and score it.
      3. Sweep thresholds and choose the best one for ``args.objective``.
      4. Write ``THRESHOLD_PATH`` (JSON) and ``ML_DIR/threshold_curve.csv``.
      5. Write ``ML_DIR/threshold_curve.png``. This step is best-effort: a
         plotting failure is only logged.

    Args:
        args: Parsed CLI namespace with ``symbols``, ``bars`` and ``objective``.
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    if len(y) == 0:
        logger.error("Dataset is empty; nothing to optimize.")
        sys.exit(1)

    # One {feature: value} dict per row, built from the float matrix instead
    # of a slow DataFrame.iterrows() pass.
    feature_dicts = [
        {name: float(value) for name, value in zip(FEATURE_COLUMNS, row)}
        for row in X
    ]
    probs = np.asarray(trade_filter.predict_batch(feature_dicts), dtype=np.float64)
    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)
    optimal = find_optimal_threshold(curve, objective=args.objective)

    matches = curve[curve["threshold"].round(4) == round(optimal, 4)]
    if matches.empty:
        # find_optimal_threshold falls back to a fixed default when no grid
        # point qualifies. That default need not be on the grid, so there is
        # no curve row to read; report only what can be computed directly.
        taken = probs >= optimal
        logger.warning(f"Threshold {optimal:.4f} is not on the sweep grid; metrics unavailable")
        best_row = pd.Series({
            "win_rate": np.nan,
            "expectancy": np.nan,
            "sharpe": np.nan,
            "n_trades": int(taken.sum()),
            "coverage": float(taken.mean()),
        })
    else:
        best_row = matches.iloc[0]

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Update threshold file. Non-finite metrics are written as null so the
    # file stays valid JSON.
    ML_DIR.mkdir(parents=True, exist_ok=True)
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": _finite_or_none(best_row["win_rate"]),
        "expectancy_at_threshold": _finite_or_none(best_row["expectancy"]),
        "sharpe_at_threshold": _finite_or_none(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold updated → {THRESHOLD_PATH}")

    # Save curve CSV
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plot. This is best-effort: the threshold and CSV are already saved, so a
    # plotting failure is logged rather than aborting the run.
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Optimize probability threshold")
parser.add_argument("--symbols", nargs="+", default=None)
parser.add_argument("--bars", type=int, default=200)
parser.add_argument("--objective", choices=["expectancy", "sharpe", "win_rate"], default=THRESHOLD_OBJECTIVE)
args = parser.parse_args()
main(args)