#!/usr/bin/env python3
"""Calibrate market regimes from roughly 10 years of historical prices.

Outputs:
- historical_regime_config.json: regime return parameters and transition matrix
- historical_regime_stats.csv: human-readable regime statistics
- historical_regime_calibration_report.md: summary for the dashboard/manual
"""

from __future__ import annotations

import argparse
import csv
import json
import math
from datetime import datetime, timezone
from pathlib import Path


# Symbols requested from yfinance in calibrate(); every one of them must be
# present in the downloaded close prices or calibration aborts.
TICKERS = [
    "SPY",
    "QQQ",
    "IWM",
    "SMH",
    "SOXX",
    "SOXL",
    "SOXS",
    "NVDA",
    "AVGO",
    "AMD",
    "MU",
    "TSM",
    "ASML",
    "^VIX",
]

# Keys of the synthetic return series produced by build_asset_returns().
ASSETS = ["basket", "quality", "leader", "soxl", "soxs"]
# Regime labels emitted by classify_regimes(); this list also fixes the row and
# column order of the transition matrix and of the report table.
REGIME_ORDER = ["risk_on", "selective", "chop", "selloff", "rebound", "crash"]


def pct_change(series):
    """Return the period-over-period fractional change of ``series``.

    Delegates to the object's own ``pct_change`` method with its default
    arguments, so the first element of the result has no predecessor and is
    missing.
    """
    changes = series.pct_change()
    return changes


def classify_regimes(close):
    """Label each row of ``close`` with a market regime.

    Args:
        close: price table indexed by date with at least the columns
            ``"QQQ"``, ``"SMH"``, ``"SPY"`` and ``"^VIX"``.

    Returns:
        dict mapping index label -> regime name, in index order. Rows for
        which any input price or derived indicator is missing are omitted;
        because a 200-row moving average is one of the indicators, at least
        the first 199 rows are never classified.

    Rules are evaluated in priority order and the first match wins:
        1. ``crash``     - QQQ 5-day return <= -7.5%, or SMH 5-day return
                           <= -10.5%, or VIX >= 35 with QQQ 3-day return < -3.5%.
        2. ``selloff``   - QQQ and SMH both below their 50-day average and
                           (QQQ 5-day return < -3.5% or VIX up > 12% in 5 days).
        3. ``rebound``   - QQQ more than 5% below its 60-day high while the
                           3-day returns exceed 2.5% (QQQ) and 3.0% (SMH).
        4. ``risk_on``   - QQQ above its 50- and 200-day averages, SMH above
                           its 50-day average, both 20-day returns positive.
        5. ``selective`` - QQQ or SMH above its 50-day average and
                           (QQQ beats SPY, or SMH beats QQQ, over 20 days).
        6. ``chop``      - everything else.
    """
    qqq = close["QQQ"]
    smh = close["SMH"]
    spy = close["SPY"]
    vix = close["^VIX"]

    qqq_ret_1 = qqq.pct_change()
    qqq_ret_3 = qqq.pct_change(3)
    qqq_ret_5 = qqq.pct_change(5)
    qqq_ret_20 = qqq.pct_change(20)
    smh_ret_3 = smh.pct_change(3)
    smh_ret_5 = smh.pct_change(5)
    smh_ret_20 = smh.pct_change(20)
    spy_ret_20 = spy.pct_change(20)
    vix_change_5 = vix.pct_change(5)
    qqq_ma50 = qqq.rolling(50).mean()
    qqq_ma200 = qqq.rolling(200).mean()
    smh_ma50 = smh.rolling(50).mean()
    qqq_high_60 = qqq.rolling(60).max()
    qqq_drawdown_60 = qqq / qqq_high_60 - 1.0

    # A comparison against NaN is simply False, so a missing input would not
    # raise - it would quietly push the day into a later rule. Require every
    # raw price and every derived indicator the rules read to be present.
    # (qqq_ret_1 is not read by any rule; it is kept from the original guard.)
    usable = qqq_ret_1.notna()
    for required in (
        qqq, smh, spy, vix,
        qqq_ret_3, qqq_ret_5, qqq_ret_20,
        smh_ret_3, smh_ret_5, smh_ret_20,
        spy_ret_20, vix_change_5,
        qqq_ma50, qqq_ma200, smh_ma50,
        qqq_drawdown_60,
    ):
        usable = usable & required.notna()

    # One boolean column per rule, computed for the whole table at once.
    crash = (
        (qqq_ret_5 <= -0.075)
        | (smh_ret_5 <= -0.105)
        | ((vix >= 35) & (qqq_ret_3 < -0.035))
    )
    selloff = (
        (qqq < qqq_ma50)
        & (smh < smh_ma50)
        & ((qqq_ret_5 < -0.035) | (vix_change_5 > 0.12))
    )
    rebound = (
        (qqq_drawdown_60 < -0.05)
        & (qqq_ret_3 > 0.025)
        & (smh_ret_3 > 0.030)
    )
    risk_on = (
        (qqq > qqq_ma50)
        & (qqq > qqq_ma200)
        & (smh > smh_ma50)
        & (qqq_ret_20 > 0)
        & (smh_ret_20 > 0)
    )
    selective = (
        ((qqq > qqq_ma50) | (smh > smh_ma50))
        & ((qqq_ret_20 > spy_ret_20) | (smh_ret_20 > qqq_ret_20))
    )

    # Listed in priority order: the first rule that fires decides the label.
    labels = ("crash", "selloff", "rebound", "risk_on", "selective")
    masks = (crash, selloff, rebound, risk_on, selective)

    regimes = {}
    for idx, ok, *hits in zip(close.index, usable, *masks):
        if not ok:
            continue
        regimes[idx] = next(
            (label for label, hit in zip(labels, hits) if hit),
            "chop",
        )

    return regimes


def build_asset_returns(close):
    """Derive the synthetic daily return series used for calibration.

    Args:
        close: price table with one column per ticker.

    Returns:
        dict with the keys ``basket``, ``quality``, ``leader``, ``soxl`` and
        ``soxs`` (in that order), each holding a daily return series:
        equal-weight SPY/QQQ/IWM/SMH, equal-weight SPY/QQQ, equal-weight
        average of whichever leader stocks are present, and the raw SOXL and
        SOXS returns.
    """
    daily = close.pct_change()
    # Leader names are optional: average only over the columns that exist.
    leaders = [
        name
        for name in ("NVDA", "AVGO", "AMD", "MU", "TSM", "ASML")
        if name in daily.columns
    ]
    return {
        "basket": daily[["SPY", "QQQ", "IWM", "SMH"]].mean(axis=1),
        "quality": daily[["SPY", "QQQ"]].mean(axis=1),
        "leader": daily[leaders].mean(axis=1),
        "soxl": daily["SOXL"],
        "soxs": daily["SOXS"],
    }


def mean_std(values):
    """Return ``(mean, sample standard deviation)`` of the finite values.

    NaN and +/-inf entries are ignored, so one corrupt return (for example a
    percentage change computed from a zero price) cannot turn the statistics
    into inf/NaN.

    Args:
        values: iterable of numbers (anything ``float()`` accepts).

    Returns:
        Tuple ``(mean, stdev)``. When no finite value is left the placeholder
        ``(0.0, 0.02)`` is returned. The standard deviation uses the n-1
        denominator and is floored at ``0.0001`` so it is never zero.
    """
    clean = [x for x in map(float, values) if math.isfinite(x)]
    if not clean:
        return 0.0, 0.02
    mean = sum(clean) / len(clean)
    # max(..., 1) keeps the single-observation case from dividing by zero.
    variance = sum((x - mean) ** 2 for x in clean) / max(len(clean) - 1, 1)
    return mean, max(math.sqrt(variance), 0.0001)


def smooth_transition_counts(regime_series, regime_order=None, smoothing=1):
    """Estimate a smoothed regime transition matrix.

    Args:
        regime_series: mapping of sortable key (e.g. date) -> regime label.
            Entries are visited in ascending key order and every consecutive
            pair counts as one observed transition.
        regime_order: labels defining the rows and columns of the matrix.
            Defaults to the module-level ``REGIME_ORDER``.
        smoothing: pseudo-count added to every cell before normalising
            (additive smoothing). The default of 1 keeps unseen transitions
            at a small non-zero probability.

    Returns:
        dict mapping each source label to a list of ``(destination,
        probability)`` pairs in ``regime_order`` order; each row sums to 1.

    Raises:
        ValueError: if ``smoothing`` is negative.
        KeyError: if ``regime_series`` contains a label not in the order.
    """
    if smoothing < 0:
        raise ValueError("smoothing must be non-negative")

    order = list(REGIME_ORDER if regime_order is None else regime_order)
    counts = {src: {dst: smoothing for dst in order} for src in order}

    ordered = [regime_series[key] for key in sorted(regime_series)]
    for src, dst in zip(ordered, ordered[1:]):
        counts[src][dst] += 1

    transitions = {}
    for src, dst_counts in counts.items():
        total = sum(dst_counts.values())
        if total:
            transitions[src] = [(dst, dst_counts[dst] / total) for dst in order]
        else:
            # Only reachable with smoothing=0 and a source that was never
            # observed: fall back to a uniform row rather than dividing by zero.
            uniform = 1 / len(order)
            transitions[src] = [(dst, uniform) for dst in order]
    return transitions


def calibrate(period: str):
    """Download prices and compute regime statistics and transitions.

    Args:
        period: look-back window passed to ``yfinance.download`` (e.g. ``"10y"``).

    Returns:
        Tuple ``(payload, stats_rows)``. ``payload`` is the JSON-serialisable
        calibration (regime counts, per-regime ``[mean, stdev]`` for every
        asset, transition matrix, metadata). ``stats_rows`` is a list of flat
        dicts, one per (regime, asset) pair, for the CSV output.

    Raises:
        RuntimeError: if nothing was downloaded, a ticker is missing from the
            result, or no trading day could be classified.
    """
    # Imported lazily so the rest of the module can be used without yfinance.
    import yfinance as yf

    data = yf.download(TICKERS, period=period, auto_adjust=True, progress=False, threads=False)
    if data is None or data.empty:
        raise RuntimeError("No data downloaded from yfinance")

    if "Close" in data:
        close = data["Close"].dropna(how="all")
    else:
        close = data.dropna(how="all")

    missing = [ticker for ticker in TICKERS if ticker not in close.columns]
    if missing:
        raise RuntimeError(f"Missing tickers in downloaded data: {missing}")

    # Keep only days on which every core series has a price.
    close = close.dropna(subset=["SPY", "QQQ", "SMH", "SOXX", "SOXL", "SOXS", "^VIX"])
    regimes_by_date = classify_regimes(close)
    if not regimes_by_date:
        # Without this check every statistic would silently fall back to its
        # placeholder and a previously saved calibration would be overwritten
        # with meaningless numbers.
        raise RuntimeError(
            f"No trading days could be classified for period {period!r}; "
            "the regime rules need at least 200 rows of complete price history"
        )
    asset_returns = build_asset_returns(close)

    # Group the classified dates by label in one pass.
    dates_by_regime = {regime: [] for regime in REGIME_ORDER}
    for date, label in regimes_by_date.items():
        dates_by_regime[label].append(date)

    regime_config = {}
    stats_rows = []
    for regime in REGIME_ORDER:
        dates = dates_by_regime[regime]
        regime_config[regime] = {}
        for asset in ASSETS:
            series = asset_returns[asset]
            values = [series.loc[date] for date in dates if date in series.index]
            mean, stdev = mean_std(values)
            regime_config[regime][asset] = [mean, stdev]
            stats_rows.append({
                "regime": regime,
                "asset": asset,
                "days": len(values),
                "mean_daily_return": mean,
                "stdev_daily_return": stdev,
                # Simple scaling by 252 trading days, no compounding.
                "annualized_mean_simple": mean * 252,
                "annualized_vol_simple": stdev * math.sqrt(252),
            })

    transitions = smooth_transition_counts(regimes_by_date)
    counts = {regime: len(dates_by_regime[regime]) for regime in REGIME_ORDER}
    payload = {
        "as_of_utc": datetime.now(timezone.utc).isoformat(),
        "period": period,
        "source": "yfinance adjusted close",
        "tickers": TICKERS,
        "regime_counts": counts,
        "regimes": regime_config,
        "transitions": transitions,
        "notes": [
            "Historical calibration is a guide, not a guarantee.",
            "SOXL/SOXS are daily leveraged products; long-horizon simulation still requires daily revalidation.",
            # The classification rules read QQQ, SMH, SPY and ^VIX.
            "Classification uses QQQ/SMH/SPY/VIX trend and drawdown rules.",
        ],
    }
    return payload, stats_rows


def write_outputs(payload, stats_rows, out_dir: Path):
    """Write the calibration config, statistics table and Markdown report.

    Creates three files in ``out_dir`` (the directory is created if needed),
    all encoded as UTF-8:

    - ``historical_regime_config.json``: ``payload`` as indented JSON
    - ``historical_regime_stats.csv``: one row per entry of ``stats_rows``
    - ``historical_regime_calibration_report.md``: regime mix summary

    Args:
        payload: calibration dict; must contain ``as_of_utc``, ``period``,
            ``source`` and ``regime_counts`` (with a count for every regime
            in ``REGIME_ORDER``).
        stats_rows: list of flat dicts sharing the same keys. May be empty,
            in which case only the CSV header is written.
        out_dir: destination directory.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    config_path = out_dir / "historical_regime_config.json"
    stats_path = out_dir / "historical_regime_stats.csv"
    report_path = out_dir / "historical_regime_calibration_report.md"

    # ensure_ascii=False can emit non-ASCII characters, so the encoding must
    # be pinned instead of relying on the platform default.
    config_path.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    if stats_rows:
        fieldnames = list(stats_rows[0].keys())
    else:
        # No rows to infer the header from; fall back to the standard columns.
        fieldnames = [
            "regime",
            "asset",
            "days",
            "mean_daily_return",
            "stdev_daily_return",
            "annualized_mean_simple",
            "annualized_vol_simple",
        ]
    with stats_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(stats_rows)

    counts = payload["regime_counts"]
    total = sum(counts.values())
    lines = [
        "# Historical Regime Calibration Report",
        "",
        f"- As of UTC: {payload['as_of_utc']}",
        f"- Period: {payload['period']}",
        f"- Source: {payload['source']}",
        "- Purpose: calibrate 1-year strategy simulations from the last 10 years of market regimes.",
        "",
        "## Regime Mix",
        "",
        "| Regime | Days | Share |",
        "| --- | ---: | ---: |",
    ]
    for regime in REGIME_ORDER:
        days = counts[regime]
        # Guard against an all-zero table.
        share = days / total if total else 0
        lines.append(f"| {regime} | {days} | {share:.2%} |")

    lines.extend([
        "",
        "## Required Use",
        "",
        "- Strategy simulations should run at least 252 trading days.",
        "- Chart and trend analysis should reference this 10-year regime calibration before selecting ATTACK/GUARDED ATTACK/DEFENSE/REST.",
        "- If live data download fails, use the last saved calibration and mark it stale.",
        "",
        "## Files",
        "",
        f"- Config: `{config_path}`",
        f"- Stats: `{stats_path}`",
    ])
    report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> int:
    """Command-line entry point.

    Parses ``--period`` (default ``"10y"``), runs the calibration, writes the
    output files next to this script and prints a short JSON summary.

    Returns:
        ``0`` on success; failures propagate as exceptions.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--period", default="10y")
    period = cli.parse_args().period

    # Outputs live in the same directory as the script itself.
    script_dir = Path(__file__).resolve().parent
    payload, stats_rows = calibrate(period)
    write_outputs(payload, stats_rows, script_dir)

    summary = {
        key: payload[key]
        for key in ("period", "as_of_utc", "regime_counts")
    }
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0


# Run the calibration only when executed as a script; the value returned by
# main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

