"""
backtest_ic.py - Framework per backtest di Iron Condor su dati storici
======================================================================
Companion code per "Trading con le Opzioni - Strategie Operative"
di Pierpaolo Marturano (Core Matrix S.r.l.)

Framework per testare strategie Iron Condor con regole sistematiche:
- Entry a 45 DTE
- Exit a 21 DTE o al 50% del profitto
- Stop loss a 2x il credito ricevuto

Requisiti: numpy, scipy, matplotlib
Compatibile con Python 3.10+

Nota: questo framework usa dati simulati per dimostrazione.
Per backtest reali, sostituire con dati storici da CBOE, OptionMetrics,
o altre fonti.
"""

import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TradeResult:
    """Outcome of a single simulated iron condor trade.

    ``entry_date`` and ``exit_date`` are integer indices into the price
    series given to the backtest, not calendar dates.
    """
    entry_date: int          # index of the entry day in the price series
    exit_date: int           # index of the exit day in the price series
    entry_price: float       # underlying price at entry
    exit_price: float        # underlying price at exit
    credit: float            # net credit received at entry
    pnl: float              # realized P&L of the trade
    exit_reason: str         # "target", "stop", "expiry", "dte_exit"
    days_held: int           # days in the position (exit_date - entry_date)
    put_short_strike: float  # strike of the short put
    call_short_strike: float  # strike of the short call


@dataclass
class BacktestConfig:
    """Parameters of the systematic iron condor backtest."""
    # Entry parameters
    entry_dte: int = 45                # days to expiration when the trade is opened
    delta_short_strike: float = 0.16   # target |delta| of both short strikes (16-delta)
    wing_width: float = 50.0           # distance between each short strike and its long wing, in price points

    # Management parameters
    profit_target_pct: float = 0.50    # close once P&L >= this fraction of the credit (50%)
    stop_loss_pct: float = 2.0         # close once the loss reaches this multiple of the credit (2x)
    exit_dte: int = 21                 # close once days remaining <= this value (21 DTE)

    # Simulation parameters
    iv_to_use: float = 0.16            # constant implied vol used to pick strikes and estimate the credit
    risk_free_rate: float = 0.045      # annual rate passed to the strike estimate


@dataclass
class BacktestResults:
    """
    Aggregate results of a backtest.

    Holds the executed trades and exposes summary statistics as read-only
    properties. Only each trade's ``pnl`` and ``days_held`` attributes are
    read, so every statistic is well defined for an empty trade list.
    """
    # Quoted forward reference: the class does not depend on TradeResult
    # being defined first.
    trades: "list[TradeResult]" = field(default_factory=list)

    @property
    def n_trades(self) -> int:
        """Number of trades."""
        return len(self.trades)

    @property
    def win_rate(self) -> float:
        """Fraction of trades with strictly positive P&L (0.0 with no trades)."""
        if not self.trades:
            return 0.0
        wins = sum(1 for t in self.trades if t.pnl > 0)
        return wins / len(self.trades)

    @property
    def total_pnl(self) -> float:
        """Sum of the P&L of all trades."""
        return sum(t.pnl for t in self.trades)

    @property
    def avg_pnl(self) -> float:
        """Mean P&L per trade (0.0 with no trades)."""
        if not self.trades:
            return 0.0
        return self.total_pnl / len(self.trades)

    @property
    def avg_winner(self) -> float:
        """Mean P&L of the trades with P&L > 0 (0.0 if there are none)."""
        winners = [t.pnl for t in self.trades if t.pnl > 0]
        return float(np.mean(winners)) if winners else 0.0

    @property
    def avg_loser(self) -> float:
        """Mean P&L of the trades with P&L <= 0 (0.0 if there are none)."""
        losers = [t.pnl for t in self.trades if t.pnl <= 0]
        return float(np.mean(losers)) if losers else 0.0

    @property
    def max_drawdown(self) -> float:
        """
        Largest peak-to-trough drop of the cumulative P&L (a value <= 0).

        The running peak starts at 0, the equity before the first trade, so
        losses taken before any new equity high are counted as drawdown too.
        """
        if not self.trades:
            return 0.0
        equity = np.cumsum([t.pnl for t in self.trades])
        # Including the initial 0 in the running max == max(running max, 0).
        peak = np.maximum(np.maximum.accumulate(equity), 0.0)
        return float(np.min(equity - peak))

    @property
    def sharpe_ratio(self) -> float:
        """
        Annualized Sharpe ratio of per-trade P&L (risk-free rate ignored).

        Assumes about 365 / 30 ~= 12 trades per year and uses the population
        standard deviation. Returns 0.0 with no trades or zero dispersion.
        """
        pnls = [t.pnl for t in self.trades]
        if not pnls or np.std(pnls) == 0:
            return 0.0
        trades_per_year = 365 / 30  # about 12 trades per year
        return float((np.mean(pnls) / np.std(pnls)) * np.sqrt(trades_per_year))

    @property
    def profit_factor(self) -> float:
        """
        Gross profit divided by gross loss.

        Returns ``inf`` when there are profits but no losses, and 0.0 when
        there is neither (no trades, or only break-even trades).
        """
        gross_profit = sum(t.pnl for t in self.trades if t.pnl > 0)
        gross_loss = abs(sum(t.pnl for t in self.trades if t.pnl <= 0))
        if gross_loss == 0:
            return float('inf') if gross_profit > 0 else 0.0
        return gross_profit / gross_loss

    @property
    def avg_days_held(self) -> float:
        """Mean number of days in position (0.0 with no trades)."""
        if not self.trades:
            return 0.0
        return float(np.mean([t.days_held for t in self.trades]))


def generate_synthetic_prices(S0: float = 5500, n_days: int = 2520,
                              annual_return: float = 0.08,
                              annual_vol: float = 0.16,
                              seed: int = 42) -> np.ndarray:
    """
    Generate a synthetic daily price series with volatility regime switching.

    The volatility alternates calm stretches (annual_vol * U(0.7, 1.1)) and
    stress stretches (annual_vol * U(1.5, 2.5)). About one random jump of
    1.5%-3.5% (either sign, added to the log-return) is inserted every 60
    days. Daily log-returns follow a lognormal diffusion with dt = 1/252, so
    2520 days is about 10 trading years.

    All randomness comes from ``np.random.default_rng(seed)``, drawn in a
    fixed order, so the same arguments always reproduce the same series.

    For real backtests, replace this function with a loader for historical
    data (e.g. SPX daily closes from Yahoo Finance or similar).

    Parameters
    ----------
    S0 : starting price; also the first element of the output
    n_days : number of simulated daily returns
    annual_return : annualized drift used in the log-return drift term
    annual_vol : base annualized volatility, scaled by each regime
    seed : seed for the random generator

    Returns
    -------
    np.ndarray of length n_days + 1: S0 followed by n_days simulated prices.
    """
    rng = np.random.default_rng(seed)
    dt = 1 / 252

    # Regime switching: alternate calm and stressed periods
    vol_regime = np.ones(n_days) * annual_vol
    i = 0
    while i < n_days:
        # Calm period: 30-119 days (upper bound of integers() is exclusive)
        calm_days = rng.integers(30, 120)
        end_calm = min(i + calm_days, n_days)
        vol_regime[i:end_calm] = annual_vol * rng.uniform(0.7, 1.1)
        i = end_calm

        # Stressed period: 10-39 days at 1.5x-2.5x the base vol
        if i < n_days:
            stress_days = rng.integers(10, 40)
            end_stress = min(i + stress_days, n_days)
            vol_regime[i:end_stress] = annual_vol * rng.uniform(1.5, 2.5)
            i = end_stress

    # Occasional jumps (e.g. overnight news, earnings) on distinct days
    n_gaps = n_days // 60  # about one gap every 60 days
    gap_days = rng.choice(n_days, size=n_gaps, replace=False)
    gap_sizes = rng.choice([-1, 1], size=n_gaps) * rng.uniform(0.015, 0.035, size=n_gaps)

    # Per-day log-returns; one standard normal draw per day, in day order
    log_returns = np.zeros(n_days)
    for day in range(n_days):
        drift_day = (annual_return - 0.5 * vol_regime[day]**2) * dt
        diffusion_day = vol_regime[day] * np.sqrt(dt)
        log_returns[day] = drift_day + diffusion_day * rng.standard_normal()

    # Add the jumps on top of the diffusion returns
    for gap_day, gap_size in zip(gap_days, gap_sizes):
        log_returns[gap_day] += gap_size

    prices = S0 * np.exp(np.cumsum(log_returns))
    prices = np.insert(prices, 0, S0)  # prepend the starting price

    return prices


def estimate_strikes(S: float, T: float, sigma: float, delta_target: float,
                     r: float = 0.045,
                     strike_step: float = 5.0) -> tuple[float, float]:
    """
    Estimate the short put/call strikes whose Black-Scholes |delta| is
    ``delta_target``, by inverting the delta formula.

    With d1 = (ln(S/K) + (r + sigma^2/2) T) / (sigma sqrt(T)), the call delta
    is N(d1) and the put delta is N(d1) - 1. Solving for K with
    z = N^-1(1 - delta_target):

        call: K = S * exp( z*sigma*sqrt(T) + (r + sigma^2/2) T)
        put:  K = S * exp(-z*sigma*sqrt(T) + (r + sigma^2/2) T)

    The strikes are then rounded onto a ``strike_step`` grid (5 points, as
    for SPX, by default) away from the money: the call up and the put down.
    For sigma, T > 0 the rounded strikes therefore have |delta| <= delta_target.

    Parameters
    ----------
    S : underlying price
    T : time to expiration in years
    sigma : implied volatility (annualized)
    delta_target : target absolute delta, strictly between 0 and 1
    r : continuously compounded risk-free rate
    strike_step : spacing of the listed strike grid (must be > 0)

    Returns
    -------
    (put_strike, call_strike)

    Raises
    ------
    ValueError : if delta_target is not in (0, 1) or strike_step <= 0.
    """
    if not 0.0 < delta_target < 1.0:
        raise ValueError(f"delta_target must be in (0, 1), got {delta_target}")
    if strike_step <= 0:
        raise ValueError(f"strike_step must be positive, got {strike_step}")

    z = norm.ppf(1.0 - delta_target)
    vol_sqrt_t = sigma * np.sqrt(T)
    # Carry term of d1: positive rates push both strikes up (toward the forward).
    carry = (r + 0.5 * sigma**2) * T

    call_strike = S * np.exp(z * vol_sqrt_t + carry)
    put_strike = S * np.exp(-z * vol_sqrt_t + carry)

    # Round away from the money so the short strikes never exceed the target delta.
    call_strike = np.ceil(call_strike / strike_step) * strike_step
    put_strike = np.floor(put_strike / strike_step) * strike_step

    return float(put_strike), float(call_strike)


def _normal_model_value(otm_distance: float, std: float) -> float:
    """
    Value of a vanilla option under a driftless normal (Bachelier) model.

    otm_distance : how far the strike is out of the money, in price points
        (positive = out of the money, negative = in the money).
    std : standard deviation of the underlying at expiration, in price points.

    value = std * pdf(d) - otm_distance * cdf(-d), with d = otm_distance / std.
    When std <= 0 this degenerates to the intrinsic value.
    """
    if std <= 0:
        return max(-otm_distance, 0.0)
    d = otm_distance / std
    return float(std * norm.pdf(d) - otm_distance * norm.cdf(-d))


def estimate_credit(S: float, put_K: float, call_K: float,
                    wing_width: float, T: float, sigma: float) -> float:
    """
    Estimate the net credit of an iron condor (approximation for simulation).

    Each vertical spread is priced as the short leg minus the long leg, which
    sits ``wing_width`` further out of the money. Both legs use a normal
    approximation whose standard deviation is S * sigma * sqrt(T). Each
    spread value is an expected payoff bounded by [0, wing_width]. When
    put_K <= call_K, at most one spread can pay at expiration, so the total
    credit also stays below ``wing_width``.

    Parameters
    ----------
    S : underlying price
    put_K, call_K : short put / short call strikes
    wing_width : distance from each short strike to its long wing
    T : time to expiration in years (negative values are treated as 0)
    sigma : implied volatility (annualized)

    Returns
    -------
    Estimated net credit in price points, floored at 1.0.
    """
    std = S * sigma * np.sqrt(max(T, 0.0))

    put_spread = (_normal_model_value(S - put_K, std)
                  - _normal_model_value(S - (put_K - wing_width), std))
    call_spread = (_normal_model_value(call_K - S, std)
                   - _normal_model_value(call_K + wing_width - S, std))

    credit = put_spread + call_spread
    return max(float(credit), 1.0)  # minimum $1


def ic_pnl_at_exit(S_entry: float, S_exit: float,
                   put_short_K: float, call_short_K: float,
                   put_long_K: float, call_long_K: float,
                   credit: float, days_held: int, total_dte: int) -> float:
    """
    Heuristic P&L of the iron condor when it is closed before expiration.

    Let f = sqrt(days_held / total_dte) be the elapsed-time factor.

    - If S_exit lies between the short strikes (inclusive), the P&L is a
      conservative share of the decayed credit: 0.7 * credit * f.
    - Otherwise, the P&L blends the expiration P&L at S_exit with a small
      residual of the credit: payoff * f + 0.3 * credit * (1 - f).

    ``S_entry`` is part of the signature but is not used by the model.
    With plain int/float inputs, ``total_dte == 0`` raises ZeroDivisionError.

    NOTE(review): the two branches do not meet at the short strikes. Just
    outside them the expiration payoff is still about ``credit``, giving
    credit*f + 0.3*credit*(1 - f), which is larger than the 0.7*credit*f
    returned just inside. Confirm whether this jump is intended.
    """
    elapsed = np.sqrt(days_held / total_dte)

    if put_short_K <= S_exit <= call_short_K:
        # Between the short strikes: time decay works for the seller.
        return credit * elapsed * 0.7

    # Beyond a short strike: value of each vertical spread at expiration.
    put_spread_value = max(0, put_short_K - S_exit) - max(0, put_long_K - S_exit)
    call_spread_value = max(0, S_exit - call_short_K) - max(0, S_exit - call_long_K)
    expiry_pnl = credit - put_spread_value - call_spread_value

    return expiry_pnl * elapsed + credit * (1 - elapsed) * 0.3


def _ic_expiry_pnl(S: float, put_short_K: float, call_short_K: float,
                   put_long_K: float, call_long_K: float,
                   credit: float) -> float:
    """Return the P&L of the iron condor settled at expiration with the underlying at ``S``."""
    put_intrinsic = max(0, put_short_K - S) - max(0, put_long_K - S)
    call_intrinsic = max(0, S - call_short_K) - max(0, S - call_long_K)
    return credit - put_intrinsic - call_intrinsic


def run_backtest(prices: np.ndarray, config: BacktestConfig,
                 *, cooldown_days: int = 5) -> BacktestResults:
    """
    Run the systematic Iron Condor backtest over a daily price series.

    Each trade is opened on day ``i`` with strikes from ``estimate_strikes``
    and a credit from ``estimate_credit``. On every following day *before*
    expiration it is marked with ``ic_pnl_at_exit`` and closed at the first
    of these events:

    - the profit target is reached;
    - the stop loss is hit;
    - ``config.exit_dte`` or fewer days remain.

    If none fires, the trade is held to expiration (``config.entry_dte`` days
    after entry) and settled at intrinsic value. The next trade opens
    ``cooldown_days`` after the exit. New trades are opened only while a full
    ``entry_dte`` window fits inside ``prices``.

    NOTE(review): each element of ``prices`` counts as one DTE day, while
    T = entry_dte / 365 uses calendar days. If ``prices`` holds trading-day
    closes (generate_synthetic_prices uses dt = 1/252), the two clocks
    disagree. Confirm which is intended.

    Parameters
    ----------
    prices : np.ndarray - daily prices of the underlying
    config : BacktestConfig - strategy configuration
    cooldown_days : int - days to wait after an exit before the next entry

    Returns
    -------
    BacktestResults
    """
    results = BacktestResults()
    n = len(prices)
    T = config.entry_dte / 365  # time to expiration at entry, same for every trade

    i = 0
    while i < n - config.entry_dte:
        S_entry = prices[i]

        # Strikes and wings
        put_short_K, call_short_K = estimate_strikes(
            S_entry, T, config.iv_to_use, config.delta_short_strike,
            config.risk_free_rate
        )
        put_long_K = put_short_K - config.wing_width
        call_long_K = call_short_K + config.wing_width

        credit = estimate_credit(S_entry, put_short_K, call_short_K,
                                 config.wing_width, T, config.iv_to_use)

        # Default outcome: held to expiration.
        exit_reason = "expiry"
        exit_day = min(i + config.entry_dte, n - 1)

        # Manage only the days strictly before expiration. The expiration day
        # itself is settled at intrinsic value below, instead of being priced
        # with the pre-expiry heuristic and labelled "dte_exit"/"target".
        for day in range(1, config.entry_dte):
            if i + day >= n:  # defensive: the while condition keeps i + day < n
                exit_day = n - 1
                break

            S_current = prices[i + day]
            days_remaining = config.entry_dte - day

            current_pnl = ic_pnl_at_exit(
                S_entry, S_current, put_short_K, call_short_K,
                put_long_K, call_long_K, credit, day, config.entry_dte
            )

            if current_pnl >= credit * config.profit_target_pct:
                exit_day = i + day
                exit_reason = "target"
                break

            if current_pnl <= -credit * config.stop_loss_pct:
                exit_day = i + day
                exit_reason = "stop"
                break

            if days_remaining <= config.exit_dte:
                exit_day = i + day
                exit_reason = "dte_exit"
                break

        # Final P&L
        days_held = exit_day - i
        S_exit = prices[exit_day]

        if exit_reason == "expiry":
            final_pnl = _ic_expiry_pnl(S_exit, put_short_K, call_short_K,
                                       put_long_K, call_long_K, credit)
        else:
            final_pnl = ic_pnl_at_exit(
                S_entry, S_exit, put_short_K, call_short_K,
                put_long_K, call_long_K, credit, days_held, config.entry_dte
            )

        results.trades.append(TradeResult(
            entry_date=i,
            exit_date=exit_day,
            entry_price=S_entry,
            exit_price=S_exit,
            credit=credit,
            pnl=final_pnl,
            exit_reason=exit_reason,
            days_held=days_held,
            put_short_strike=put_short_K,
            call_short_strike=call_short_K,
        ))

        # Next entry after the cooldown; always advance to avoid an endless loop
        # when cooldown_days <= 0 and a trade exits on its entry day.
        i = max(exit_day + cooldown_days, i + 1)

    return results


def plot_backtest_results(results: BacktestResults, prices: np.ndarray,
                          save_path: Optional[str] = None) -> None:
    """
    Plot the backtest results in three panels.

    The panels are: the equity curve, the per-trade P&L bars, and a text
    panel of summary statistics.

    Parameters
    ----------
    results : BacktestResults - trades and statistics to display
    prices : np.ndarray - underlying price series; not used by the current
        plots, kept for interface compatibility
    save_path : str, optional - if given, the figure is also saved there
        (150 dpi) before ``plt.show()`` is called
    """
    fig, axes = plt.subplots(3, 1, figsize=(14, 12))

    pnls = [t.pnl for t in results.trades]
    cumulative_pnl = np.cumsum(pnls)
    trade_idx = np.arange(len(cumulative_pnl))

    # 1. Equity curve. interpolate=True makes the green/red shading meet at
    # the zero crossings instead of leaving the straddling segments unfilled.
    ax1 = axes[0]
    ax1.plot(cumulative_pnl, color='#0d5c4d', linewidth=1.5)
    ax1.fill_between(trade_idx, cumulative_pnl, 0,
                     where=(cumulative_pnl >= 0), interpolate=True,
                     color='#10b981', alpha=0.2)
    ax1.fill_between(trade_idx, cumulative_pnl, 0,
                     where=(cumulative_pnl < 0), interpolate=True,
                     color='#ef4444', alpha=0.2)
    ax1.axhline(y=0, color='gray', linewidth=0.8)
    ax1.set_title('Equity Curve - Iron Condor Backtest', fontweight='bold', fontsize=12)
    ax1.set_xlabel('Trade #')
    ax1.set_ylabel('P&L Cumulativo ($)')
    ax1.grid(True, alpha=0.3)

    # 2. P&L of each trade, green for winners and red otherwise
    ax2 = axes[1]
    colors = ['#10b981' if p > 0 else '#ef4444' for p in pnls]
    ax2.bar(range(len(pnls)), pnls, color=colors, alpha=0.7, width=1.0)
    ax2.axhline(y=0, color='gray', linewidth=0.8)
    ax2.axhline(y=results.avg_pnl, color='blue', linewidth=1, linestyle='--',
                label=f'Media: ${results.avg_pnl:.2f}')
    ax2.set_title('P&L per Trade', fontweight='bold', fontsize=12)
    ax2.set_xlabel('Trade #')
    ax2.set_ylabel('P&L ($)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. Summary statistics panel
    ax3 = axes[2]
    ax3.axis('off')

    exit_reasons = {}
    for t in results.trades:
        exit_reasons[t.exit_reason] = exit_reasons.get(t.exit_reason, 0) + 1

    exit_str = ", ".join(f"{k}: {v}" for k, v in sorted(exit_reasons.items()))

    stats = f"""
    RISULTATI BACKTEST IRON CONDOR
    {'═' * 50}

    Trades totali:       {results.n_trades}
    Win rate:            {results.win_rate*100:.1f}%

    P&L totale:          ${results.total_pnl:.2f}
    P&L medio:           ${results.avg_pnl:.2f}
    Media vincite:       ${results.avg_winner:.2f}
    Media perdite:       ${results.avg_loser:.2f}

    Max drawdown:        ${results.max_drawdown:.2f}
    Profit factor:       {results.profit_factor:.2f}
    Sharpe ratio:        {results.sharpe_ratio:.2f}

    Giorni medi:         {results.avg_days_held:.1f}
    Exit reasons:        {exit_str}
    """

    ax3.text(0.05, 0.95, stats, transform=ax3.transAxes,
             fontsize=11, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Grafico salvato in: {save_path}")

    plt.show()


# =============================================================================
# ESEMPIO D'USO
# =============================================================================

if __name__ == "__main__":
    # Demo: backtest the standard configuration on ~10 years of synthetic
    # prices, plot it, then compare several short-strike deltas on the same
    # price series.
    from dataclasses import replace

    print("=" * 60)
    print("BACKTEST IRON CONDOR - 10 anni simulati")
    print("=" * 60)

    # Synthetic prices (about 10 years of daily data)
    prices = generate_synthetic_prices(
        S0=4000, n_days=2520, annual_return=0.08,
        annual_vol=0.16, seed=42
    )

    print(f"\nDati: {len(prices)} giorni, da ${prices[0]:.0f} a ${prices[-1]:.0f}")

    # Standard configuration
    config = BacktestConfig(
        entry_dte=45,
        delta_short_strike=0.16,
        wing_width=50,
        profit_target_pct=0.50,
        stop_loss_pct=2.0,
        exit_dte=21,
        iv_to_use=0.16,
    )

    print("\nConfigurazione:")
    print(f"  Entry: {config.entry_dte} DTE, delta {config.delta_short_strike}")
    print(f"  Ali: ${config.wing_width}")
    print(f"  Target: {config.profit_target_pct*100:.0f}% del credito")
    print(f"  Stop: {config.stop_loss_pct:.0f}x il credito")
    print(f"  Exit: {config.exit_dte} DTE")

    # Run the backtest
    results = run_backtest(prices, config)

    print(f"\n{'─' * 40}")
    print("RISULTATI:")
    print(f"  Trades: {results.n_trades}")
    print(f"  Win rate: {results.win_rate*100:.1f}%")
    print(f"  P&L totale: ${results.total_pnl:.2f}")
    print(f"  P&L medio: ${results.avg_pnl:.2f}")
    print(f"  Profit factor: {results.profit_factor:.2f}")
    print(f"  Max drawdown: ${results.max_drawdown:.2f}")
    print(f"  Sharpe ratio: {results.sharpe_ratio:.2f}")

    # Plot
    plot_backtest_results(results, prices)

    # Compare different short-strike deltas
    print(f"\n{'═' * 60}")
    print("CONFRONTO: diversi livelli di delta")
    print(f"{'═' * 60}")

    for delta in [0.10, 0.16, 0.20, 0.25]:
        # Vary only the delta; every other parameter matches `config` above.
        config_test = replace(config, delta_short_strike=delta)
        res = run_backtest(prices, config_test)
        print(f"  Delta {delta:.2f}: WR={res.win_rate*100:.0f}%, "
              f"Avg=${res.avg_pnl:.2f}, PF={res.profit_factor:.2f}, "
              f"Sharpe={res.sharpe_ratio:.2f}")
