"""
vol_surface.py - Visualizzazione della superficie di volatilità implicita
=========================================================================
Companion code per "Trading con le Opzioni - Strategie Operative"
di Pierpaolo Marturano (Core Matrix S.r.l.)

Costruisce e visualizza:
- Volatility smile per singola scadenza
- Term structure della volatilità
- Superficie completa 3D (strike x scadenza x IV)
- Skew analysis

Requisiti: numpy, scipy, matplotlib
Compatibile con Python 3.10+

Nota: utilizza dati sintetici per dimostrazione. Per dati reali,
sostituire con feed da broker (IBKR TWS API) o data provider.
"""

import numpy as np
from scipy.interpolate import griddata, CubicSpline
from scipy.stats import norm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def generate_synthetic_surface(S: float = 5500, r: float = 0.045,
                               atm_vol: float = 0.16,
                               skew_slope: float = -0.12,
                               term_slope: float = 0.02,
                               smile_curvature: float = 0.05,
                               iv_floor: float = 0.05) -> dict:
    """
    Generate a realistic-looking synthetic implied-volatility surface.

    For each expiry T (in years) and log-moneyness k = ln(K/S) the model is

        IV(T, k) = atm_vol + term_slope * sqrt(T)
                   + skew_slope / sqrt(T + 0.01) * k
                   + smile_curvature / (T + 0.05) * k**2

    floored at ``iv_floor``. With the default parameters this produces:
    - a negative skew (OTM puts richer than OTM calls, typical of equity)
    - an upward-sloping term structure (contango)
    - skew and smile curvature that are stronger on short expiries

    Parameters
    ----------
    S : float - Spot price of the underlying.
    r : float - Risk-free rate. Not used by the model; kept for interface
        compatibility.
    atm_vol : float - Reference ATM implied volatility.
    skew_slope : float - Skew slope in log-moneyness (negative for equity).
    term_slope : float - Term-structure slope in sqrt(T).
    smile_curvature : float - Quadratic smile curvature in log-moneyness.
    iv_floor : float - Minimum IV assigned to any grid point (default 0.05,
        i.e. 5%). Keeps the far wings from going non-positive.

    Returns
    -------
    dict with keys:
        "strikes"           - (25,) strikes, S * moneyness
        "moneyness"         - (25,) K/S grid from 0.85 to 1.15
        "expirations_days"  - (10,) days to expiry
        "expirations_years" - (10,) expiries in years (days / 365)
        "iv_matrix"         - (10, 25) IVs; rows = expiries, cols = strikes
        "spot"              - S
    """
    # Strike grid expressed as moneyness (K/S).
    moneyness = np.linspace(0.85, 1.15, 25)
    strikes = S * moneyness

    # Expiries in calendar days and in years.
    expirations_days = np.array([7, 14, 21, 30, 45, 60, 90, 120, 180, 365])
    expirations_years = expirations_days / 365

    # Broadcasting: T is a column (n_expiries, 1) and log_m a row
    # (1, n_strikes), so each term below is the full (n_expiries, n_strikes)
    # grid. The arithmetic matches the former per-element double loop.
    T = expirations_years[:, np.newaxis]
    log_m = np.log(moneyness)[np.newaxis, :]

    # ATM level per expiry (term structure).
    atm_T = atm_vol + term_slope * np.sqrt(T)

    # Skew: linear in log-moneyness, stronger on short expiries.
    skew_contribution = (skew_slope / np.sqrt(T + 0.01)) * log_m

    # Smile: quadratic in log-moneyness, more pronounced on short expiries.
    smile_contribution = (smile_curvature / (T + 0.05)) * log_m**2

    iv_matrix = np.maximum(atm_T + skew_contribution + smile_contribution,
                           iv_floor)

    return {
        "strikes": strikes,
        "moneyness": moneyness,
        "expirations_days": expirations_days,
        "expirations_years": expirations_years,
        "iv_matrix": iv_matrix,
        "spot": S,
    }


def plot_smile(surface: dict, expiry_idx: int | list[int] | None = None,
               save_path: str | None = None) -> None:
    """
    Plot the volatility smile (IV vs. moneyness) for one or more expiries.

    Parameters
    ----------
    surface : dict - Surface as returned by ``generate_synthetic_surface``;
        reads "expirations_days", "moneyness" and "iv_matrix".
    expiry_idx : int, list of int, or None - Row index (or indices) into
        ``surface["iv_matrix"]``. A single Python or NumPy integer is
        accepted. ``None`` selects [2, 4, 6, 9], which are 21, 45, 90 and
        365 DTE on the default synthetic grid.
    save_path : str or None - If given, the figure is also saved there
        at 150 dpi before being shown.
    """
    if expiry_idx is None:
        expiry_idx = [2, 4, 6, 9]  # 21, 45, 90, 365 DTE on the default grid
    elif isinstance(expiry_idx, (int, np.integer)):
        # NumPy integers (e.g. from np.argmin) are not `int`; without this
        # they would reach zip() below and raise a TypeError.
        expiry_idx = [expiry_idx]

    fig, ax = plt.subplots(figsize=(10, 6))

    # Loop-invariant: every smile shares the same x axis.
    moneyness_pct = surface["moneyness"] * 100
    colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(expiry_idx)))

    for idx, color in zip(expiry_idx, colors):
        days = surface["expirations_days"][idx]
        iv_pct = surface["iv_matrix"][idx, :] * 100

        ax.plot(moneyness_pct, iv_pct, linewidth=2, color=color,
                label=f'{days} DTE', marker='o', markersize=3)

    # Dashed reference line at the money (K/S = 100).
    ax.axvline(x=100, color='gray', linewidth=0.8, linestyle='--', alpha=0.5)
    ax.set_title('Volatility Smile per Scadenza', fontweight='bold', fontsize=13)
    ax.set_xlabel('Moneyness (K/S × 100)', fontsize=11)
    ax.set_ylabel('Implied Volatility (%)', fontsize=11)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def plot_term_structure(surface: dict,
                        moneyness_levels: list[float] | None = None,
                        save_path: str | None = None) -> None:
    """
    Plot the IV term structure (IV vs. days to expiry) at several moneyness levels.

    Each requested level is snapped to the nearest strike column of
    ``surface["moneyness"]``. The legend shows the moneyness actually
    plotted, so it cannot mislabel a curve when the level is not on the grid.

    Parameters
    ----------
    surface : dict - Surface as returned by ``generate_synthetic_surface``;
        reads "moneyness", "expirations_days" and "iv_matrix".
    moneyness_levels : list of float or None - K/S levels to plot.
        ``None`` selects [0.90, 0.95, 1.00, 1.05, 1.10].
    save_path : str or None - If given, the figure is also saved there
        at 150 dpi before being shown.
    """
    if moneyness_levels is None:
        moneyness_levels = [0.90, 0.95, 1.00, 1.05, 1.10]

    fig, ax = plt.subplots(figsize=(10, 6))

    palette = ['#ef4444', '#f59e0b', '#10b981', '#3b82f6', '#8b5cf6']
    if len(moneyness_levels) <= len(palette):
        colors = palette
    else:
        # The fixed palette has only 5 colors. zip() would silently drop
        # every level beyond the fifth, so use a colormap sized to the request.
        colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(moneyness_levels)))

    days = surface["expirations_days"]

    for m, color in zip(moneyness_levels, colors):
        # Nearest strike column on the grid.
        idx = np.argmin(np.abs(surface["moneyness"] - m))
        m_actual = surface["moneyness"][idx]
        iv_pct = surface["iv_matrix"][:, idx] * 100

        # isclose: grid values from linspace are not exactly 1.0.
        if np.isclose(m_actual, 1.0):
            label = 'ATM (K/S = 100%)'
        else:
            label = f'K/S = {m_actual:.0%}'

        ax.plot(days, iv_pct, linewidth=2,
                color=color, label=label, marker='s', markersize=5)

    ax.set_title('Term Structure della Volatilità Implicita', fontweight='bold', fontsize=13)
    ax.set_xlabel('Giorni a Scadenza (DTE)', fontsize=11)
    ax.set_ylabel('Implied Volatility (%)', fontsize=11)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    # Keep the original 0-370 day window, widening it if the surface has
    # longer expiries so they are not clipped.
    x_max = max(370, float(np.max(days)) + 5) if len(days) else 370
    ax.set_xlim(0, x_max)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def plot_surface_3d(surface: dict, save_path: str | None = None) -> None:
    """
    Render the full implied-volatility surface as a 3D plot.

    Axes: x = moneyness (K/S x 100), y = days to expiry, z = IV in percent.
    Reads "moneyness", "expirations_days" and "iv_matrix" from ``surface``.
    If ``save_path`` is given, the figure is saved there at 150 dpi;
    ``plt.show()`` is then called.
    """
    # Grid axes in display units.
    strike_axis = surface["moneyness"] * 100
    dte_axis = surface["expirations_days"]
    iv_pct = surface["iv_matrix"] * 100

    fig = plt.figure(figsize=(12, 8))
    ax3d = fig.add_subplot(111, projection='3d')

    # meshgrid rows follow the expiries and columns follow the strikes,
    # matching the (expiry, strike) layout of iv_matrix.
    grid_m, grid_dte = np.meshgrid(strike_axis, dte_axis)
    mesh = ax3d.plot_surface(grid_m, grid_dte, iv_pct,
                             cmap='viridis', alpha=0.8,
                             edgecolor='none', antialiased=True)

    ax3d.set_title('Superficie di Volatilità Implicita', fontweight='bold', fontsize=13)
    ax3d.set_xlabel('Moneyness (K/S × 100)')
    ax3d.set_ylabel('DTE')
    ax3d.set_zlabel('IV (%)')
    ax3d.view_init(elev=25, azim=-60)

    fig.colorbar(mesh, ax=ax3d, shrink=0.5, aspect=10, label='IV (%)')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def plot_skew_analysis(surface: dict, save_path: str | None = None,
                       otm_offset: float = 0.05) -> None:
    """
    Skew analysis, one value per expiry, in two side-by-side panels.

    Left:  skew = IV(put wing) - IV(call wing), in vol points. This is the
           put-minus-call risk reversal.
    Right: butterfly = (IV(put wing) + IV(call wing)) / 2 - IV(ATM), in vol
           points. It measures smile curvature.

    The put/call wings are the strike columns nearest to moneyness
    ``1 - otm_offset`` and ``1 + otm_offset``; ATM is the column nearest
    to 1.0.

    NOTE: with the default 5% offset the panels are labelled "25Δ". This is
    only a fixed-moneyness proxy. The true 25-delta strike depends on expiry
    and volatility, and lies further OTM as expiry grows.

    Parameters
    ----------
    surface : dict - Surface as returned by ``generate_synthetic_surface``;
        reads "moneyness", "expirations_days" and "iv_matrix".
    save_path : str or None - If given, the figure is also saved there
        at 150 dpi before being shown.
    otm_offset : float - Distance of the wings from ATM in moneyness
        (default 0.05, i.e. 95% / 105%).

    Raises
    ------
    ValueError - If ``otm_offset`` is not strictly between 0 and 1.
    """
    if not 0 < otm_offset < 1:
        raise ValueError(f"otm_offset must be in (0, 1), got {otm_offset!r}")

    moneyness = surface["moneyness"]
    iv = surface["iv_matrix"]
    days = surface["expirations_days"]

    # Nearest grid columns to the requested wings and to ATM.
    put_idx = np.argmin(np.abs(moneyness - (1 - otm_offset)))
    call_idx = np.argmin(np.abs(moneyness - (1 + otm_offset)))
    atm_idx = np.argmin(np.abs(moneyness - 1.00))

    # Keep the conventional "25Δ" label only for the default 5% proxy;
    # otherwise state explicitly how far OTM the wings are.
    wing = '25Δ' if np.isclose(otm_offset, 0.05) else f'{otm_offset:.0%} OTM'

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # 1. Skew (put-minus-call risk reversal) per expiry, as a bar chart with
    #    one categorical slot per expiry.
    ax1 = axes[0]
    skew = (iv[:, put_idx] - iv[:, call_idx]) * 100
    positions = range(len(days))

    ax1.bar(positions, skew, color='#0d5c4d', alpha=0.7)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(days, fontsize=9)
    ax1.set_title(f'Skew ({wing} Put IV - {wing} Call IV)', fontweight='bold')
    ax1.set_xlabel('DTE')
    ax1.set_ylabel('Skew (punti %)')
    ax1.grid(True, alpha=0.3)

    # 2. Butterfly: average wing IV minus ATM IV (smile curvature).
    ax2 = axes[1]
    butterfly = ((iv[:, put_idx] + iv[:, call_idx]) / 2 - iv[:, atm_idx]) * 100

    ax2.plot(days, butterfly, 'o-',
             color='#8b5cf6', linewidth=2, markersize=6)
    ax2.set_title('Butterfly Spread (curvatura del smile)', fontweight='bold')
    ax2.set_xlabel('DTE')
    ax2.set_ylabel('Butterfly (punti %)')
    ax2.grid(True, alpha=0.3)
    ax2.axhline(y=0, color='gray', linewidth=0.8, linestyle='--')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def iv_rank(current_iv: float, iv_history: np.ndarray) -> float:
    """
    Compute the IV Rank: where the current IV sits within the historical
    min-max range (e.g. one year of observations).

        IV Rank = (IV_current - IV_min) / (IV_max - IV_min)

    This is NOT a percentile (see ``iv_percentile`` for that). The result
    is not clipped: it is below 0 or above 1 when ``current_iv`` lies
    outside the historical range. NaN entries in ``iv_history`` (e.g.
    missing days in a data feed) are ignored.

    Parameters
    ----------
    current_iv : float - Current implied volatility, in the same units as
        the history.
    iv_history : array-like - Historical IV observations.

    Returns
    -------
    float - The rank; 0.5 when the history is flat (max == min).

    Raises
    ------
    ValueError - If ``iv_history`` has no non-NaN values.
    """
    history = np.asarray(iv_history, dtype=float)
    # Drop NaNs: otherwise min/max become NaN and the rank is silently NaN.
    history = history[~np.isnan(history)]
    if history.size == 0:
        raise ValueError("iv_history must contain at least one non-NaN value")

    iv_min = history.min()
    iv_max = history.max()
    if iv_max == iv_min:
        # Degenerate range: no meaningful position, report the midpoint.
        return 0.5
    return (current_iv - iv_min) / (iv_max - iv_min)


def iv_percentile(current_iv: float, iv_history: np.ndarray) -> float:
    """
    Compute the IV Percentile: the fraction of historical observations
    whose IV is strictly below ``current_iv``.

    Returns a value in [0, 1] (multiply by 100 for a percentage). NaN
    entries in ``iv_history`` are ignored. Otherwise they would count as
    "not lower" and bias the percentile downward.

    Parameters
    ----------
    current_iv : float - Current implied volatility, in the same units as
        the history.
    iv_history : array-like - Historical IV observations.

    Raises
    ------
    ValueError - If ``iv_history`` has no non-NaN values (the mean of an
        empty array would otherwise be NaN with a RuntimeWarning).
    """
    history = np.asarray(iv_history, dtype=float)
    history = history[~np.isnan(history)]
    if history.size == 0:
        raise ValueError("iv_history must contain at least one non-NaN value")
    return np.mean(history < current_iv)


# =============================================================================
# ESEMPIO D'USO
# =============================================================================

if __name__ == "__main__":
    print("=" * 60)
    print("SUPERFICIE DI VOLATILITÀ - Visualizzazione")
    print("=" * 60)

    # Build the synthetic (demo) surface.
    surface = generate_synthetic_surface(
        S=5500, atm_vol=0.16, skew_slope=-0.12,
        term_slope=0.02, smile_curvature=0.05
    )

    # Look up the ~30 DTE row and the ATM column instead of hard-coding
    # [3, 12], so the summary stays correct if the grid changes.
    row_30d = int(np.argmin(np.abs(surface["expirations_days"] - 30)))
    col_atm = int(np.argmin(np.abs(surface["moneyness"] - 1.0)))
    dte_30 = surface["expirations_days"][row_30d]

    print("\nSuperficie generata:")
    print(f"  Spot: ${surface['spot']}")
    print(f"  Strike range: ${surface['strikes'][0]:.0f} - ${surface['strikes'][-1]:.0f}")
    print(f"  Scadenze: {surface['expirations_days']} giorni")
    print(f"  ATM IV ({dte_30} DTE): {surface['iv_matrix'][row_30d, col_atm]*100:.1f}%")

    # Plots.
    print("\n1. Volatility Smile...")
    plot_smile(surface)

    print("\n2. Term Structure...")
    plot_term_structure(surface)

    print("\n3. Superficie 3D...")
    plot_surface_3d(surface)

    print("\n4. Skew Analysis...")
    plot_skew_analysis(surface)

    # IV Rank and IV Percentile.
    print("\n" + "=" * 60)
    print("IV RANK & PERCENTILE")
    print("=" * 60)

    # Simulate 252 trading days of IV history: a slow oscillation plus
    # noise, floored at 8%.
    rng = np.random.default_rng(42)
    iv_history = (0.16 + 0.04 * np.sin(np.linspace(0, 4*np.pi, 252))
                  + 0.02 * rng.standard_normal(252))
    iv_history = np.maximum(iv_history, 0.08)

    current = 0.18
    rank = iv_rank(current, iv_history)
    pctile = iv_percentile(current, iv_history)

    print(f"\n  IV corrente: {current*100:.1f}%")
    print(f"  IV min (1Y): {np.min(iv_history)*100:.1f}%")
    print(f"  IV max (1Y): {np.max(iv_history)*100:.1f}%")
    print(f"  IV Rank:       {rank*100:.1f}%")
    print(f"  IV Percentile: {pctile*100:.1f}%")
    print("\n  Interpretazione:")
    if rank > 0.5:
        print("  → IV elevata: favorevole per strategie di vendita (theta)")
    else:
        print("  → IV bassa: favorevole per strategie di acquisto (vega)")
