"""
payoff.py - Generazione di grafici payoff per combinazioni di opzioni
=====================================================================
Companion code per "Trading con le Opzioni - Strategie Operative"
di Pierpaolo Marturano (Core Matrix S.r.l.)

Supporta: singole opzioni, spread verticali, iron condor, butterfly,
straddle, strangle, covered call, e qualsiasi combinazione custom.

Requisiti: numpy, matplotlib
Compatibile con Python 3.10+
"""

import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Optional


@dataclass
class Leg:
    """Single leg of an option strategy.

    ``option_type`` is matched case-insensitively and normalised to
    ``"call"`` or ``"put"`` on construction. Any other value raises
    ``ValueError``; previously it was silently priced as a put.
    """
    option_type: str   # "call" or "put" (case-insensitive on input)
    strike: float      # exercise price of the option
    premium: float     # premium per unit: paid when long, received when short
    quantity: int      # positive = long, negative = short

    def __post_init__(self) -> None:
        kind = str(self.option_type).strip().lower()
        if kind not in ("call", "put"):
            raise ValueError(
                f"option_type must be 'call' or 'put', got {self.option_type!r}")
        self.option_type = kind

    def payoff_at_expiry(self, S: np.ndarray) -> np.ndarray:
        """Return this leg's profit/loss at expiry for every price in *S*.

        The value is ``quantity * (intrinsic - premium)``. A long leg
        (quantity > 0) pays the premium. A short leg (quantity < 0)
        receives it. *S* may be any array-like of prices.
        """
        S = np.asarray(S, dtype=float)
        if self.option_type == "call":
            intrinsic = np.maximum(S - self.strike, 0.0)
        else:
            intrinsic = np.maximum(self.strike - S, 0.0)

        return self.quantity * (intrinsic - self.premium)


class Strategy:
    """Option strategy made of one or more legs.

    The statistics (max profit, max loss, breakevens) are evaluated on the
    price grid the caller passes in. For strategies with unbounded profit
    or loss, the reported extremes therefore depend on the grid end points.
    """

    def __init__(self, name: str = "Custom Strategy"):
        self.name = name
        self.legs: list[Leg] = []

    def add_leg(self, option_type: str, strike: float, premium: float,
                quantity: int) -> "Strategy":
        """Append a leg to the strategy and return ``self`` for chaining."""
        self.legs.append(Leg(option_type, strike, premium, quantity))
        return self

    def total_payoff(self, S: np.ndarray) -> np.ndarray:
        """Return the summed expiry payoff of all legs for each price in *S*.

        A strategy with no legs yields an all-zero float array shaped like
        *S*. The built-in ``sum`` of an empty sequence would return the
        scalar ``0`` instead.
        """
        S = np.asarray(S, dtype=float)
        total = np.zeros_like(S)
        for leg in self.legs:
            total = total + leg.payoff_at_expiry(S)
        return total

    def max_profit(self, S: np.ndarray) -> float:
        """Return the largest total payoff found on the price grid *S*."""
        return float(np.max(self.total_payoff(S)))

    def max_loss(self, S: np.ndarray) -> float:
        """Return the smallest (most negative) total payoff found on *S*."""
        return float(np.min(self.total_payoff(S)))

    def breakevens(self, S: np.ndarray) -> list[float]:
        """Return the prices where the total payoff changes sign.

        Between two grid points with opposite signs the zero is found by
        linear interpolation. A grid point where the payoff is exactly zero
        is reported once, not twice. Previously both adjacent sign
        transitions (-1 -> 0 and 0 -> 1) produced an entry.
        """
        S = np.asarray(S, dtype=float)
        payoff = self.total_payoff(S)
        sign_changes = np.where(np.diff(np.sign(payoff)))[0]

        breakevens: list[float] = []
        for idx in sign_changes:
            x1, x2 = S[idx], S[idx + 1]
            y1, y2 = payoff[idx], payoff[idx + 1]
            # A sign change guarantees y1 != y2, so the division is safe.
            if y1 == 0:
                be = x1
            elif y2 == 0:
                be = x2
            else:
                be = x1 - y1 * (x2 - x1) / (y2 - y1)
            be = float(be)
            if not breakevens or be != breakevens[-1]:
                breakevens.append(be)
        return breakevens

    def plot(self, S_range: Optional[tuple[float, float]] = None,
             figsize: tuple[float, float] = (10, 6),
             save_path: Optional[str] = None) -> None:
        """
        Draw the expiry payoff diagram and show it.

        Parameters
        ----------
        S_range : tuple - (min, max) for the X axis. Derived from the leg
            strikes when None.
        figsize : tuple - figure size passed to ``plt.subplots``
        save_path : str - if given, the figure is also saved there (150 dpi)

        Raises
        ------
        ValueError
            If ``S_range`` is None and the strategy has no legs, so no
            strikes exist to derive the range from.
        """
        if S_range is None:
            if not self.legs:
                raise ValueError(
                    "Cannot derive S_range for a strategy with no legs; "
                    "pass S_range explicitly")
            strikes = [leg.strike for leg in self.legs]
            center = float(np.mean(strikes))
            spread = max(strikes) - min(strikes)
            margin = max(spread * 0.5, abs(center) * 0.15)
            if margin <= 0:
                # Every strike is 0: use a unit window so the axis is not empty.
                margin = 1.0
            S_range = (center - margin, center + margin)

        S = np.linspace(S_range[0], S_range[1], 1000)
        payoff = self.total_payoff(S)
        # Computed once and reused by the annotations below.
        max_p = float(np.max(payoff))
        max_l = float(np.min(payoff))
        breakevens = self.breakevens(S)

        fig, ax = plt.subplots(figsize=figsize)

        # Shaded profit / loss areas
        ax.fill_between(S, payoff, 0, where=(payoff >= 0),
                        color='#10b981', alpha=0.15, label='Profitto')
        ax.fill_between(S, payoff, 0, where=(payoff < 0),
                        color='#ef4444', alpha=0.15, label='Perdita')

        # Payoff curve
        ax.plot(S, payoff, color='#1a1a2e', linewidth=2)

        # Zero line
        ax.axhline(y=0, color='gray', linewidth=0.8, linestyle='-')

        # Breakeven markers
        for be in breakevens:
            ax.axvline(x=be, color='#f59e0b', linewidth=1, linestyle='--', alpha=0.7)
            ax.annotate(f'BE: ${be:.1f}', xy=(be, 0),
                        xytext=(be, max_p * 0.1),
                        fontsize=9, ha='center', color='#f59e0b')

        # Strike lines
        for leg in self.legs:
            ax.axvline(x=leg.strike, color='gray', linewidth=0.5,
                       linestyle=':', alpha=0.5)

        # Titles and labels
        ax.set_title(f'{self.name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Prezzo del sottostante a scadenza ($)', fontsize=11)
        ax.set_ylabel('Profitto / Perdita ($)', fontsize=11)

        info_text = f'Max Profit: ${max_p:.0f}\nMax Loss: ${max_l:.0f}'
        if breakevens:
            be_str = ', '.join(f'${be:.1f}' for be in breakevens)
            info_text += f'\nBreakeven: {be_str}'

        ax.text(0.02, 0.98, info_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        ax.legend(loc='upper right', fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(S_range)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Grafico salvato in: {save_path}")

        plt.show()


# =============================================================================
# STRATEGIE PRE-CONFIGURATE
# =============================================================================

def iron_condor(put_long_K: float, put_short_K: float,
                call_short_K: float, call_long_K: float,
                net_credit: float) -> Strategy:
    """
    Build an Iron Condor (short put spread plus short call spread).

    Parameters
    ----------
    put_long_K : strike of the bought put (lowest)
    put_short_K : strike of the sold put
    call_short_K : strike of the sold call
    call_long_K : strike of the bought call (highest)
    net_credit : net credit collected for the whole position

    Returns
    -------
    Strategy whose expiry payoff is
        net_credit - max(0, put_short_K - S) + max(0, put_long_K - S)
                   - max(0, S - call_short_K) + max(0, S - call_long_K)

    Raises
    ------
    ValueError
        If the strikes are not ordered
        put_long_K < put_short_K <= call_short_K < call_long_K
        (equal short strikes, i.e. an iron butterfly, are allowed).
    """
    if not (put_long_K < put_short_K <= call_short_K < call_long_K):
        raise ValueError(
            "Iron condor strikes must satisfy "
            "put_long_K < put_short_K <= call_short_K < call_long_K, got "
            f"{put_long_K}/{put_short_K}/{call_short_K}/{call_long_K}")

    # Only the net credit affects the expiry payoff. It is split evenly
    # between the two short legs, and the long wings carry zero premium.
    half_credit = net_credit / 2

    strategy = Strategy(f"Iron Condor {put_short_K}/{call_short_K}")
    strategy.add_leg("put", put_long_K, 0, 1)               # long put (wing)
    strategy.add_leg("put", put_short_K, half_credit, -1)   # short put
    strategy.add_leg("call", call_short_K, half_credit, -1) # short call
    strategy.add_leg("call", call_long_K, 0, 1)             # long call (wing)
    return strategy


def bull_put_spread(short_K: float, long_K: float, net_credit: float) -> Strategy:
    """Build a Bull Put Spread (credit put spread).

    The whole ``net_credit`` is booked on the sold put and the bought put
    carries zero premium. The expiry payoff is therefore
    ``net_credit - max(0, short_K - S) + max(0, long_K - S)``.
    """
    return (
        Strategy(f"Bull Put Spread {short_K}/{long_K}")
        .add_leg("put", short_K, net_credit, -1)  # sold put, collects the credit
        .add_leg("put", long_K, 0, 1)             # bought put
    )


def bear_call_spread(short_K: float, long_K: float, net_credit: float) -> Strategy:
    """Build a Bear Call Spread (credit call spread).

    The whole ``net_credit`` is booked on the sold call and the bought call
    carries zero premium. The expiry payoff is therefore
    ``net_credit - max(0, S - short_K) + max(0, S - long_K)``.
    """
    return (
        Strategy(f"Bear Call Spread {short_K}/{long_K}")
        .add_leg("call", short_K, net_credit, -1)  # sold call, collects the credit
        .add_leg("call", long_K, 0, 1)             # bought call
    )


def butterfly(low_K: float, mid_K: float, high_K: float,
              net_debit: float, option_type: str = "call") -> Strategy:
    """Build a Butterfly Spread: long 1 low strike, short 2 mid, long 1 high.

    All legs use ``option_type``. The net debit is charged entirely to the
    lowest-strike leg and the others carry zero premium. The expiry payoff
    therefore equals the net intrinsic value minus ``net_debit``.
    """
    layout = (
        (low_K, net_debit, 1),   # lower wing, pays the whole debit
        (mid_K, 0, -2),          # body, sold twice
        (high_K, 0, 1),          # upper wing
    )
    fly = Strategy(f"Butterfly {low_K}/{mid_K}/{high_K}")
    for strike, premium, quantity in layout:
        fly.add_leg(option_type, strike, premium, quantity)
    return fly


def straddle(strike: float, total_premium: float, direction: str = "long") -> Strategy:
    """Build a long or short Straddle: one call and one put at the same strike.

    ``total_premium`` is split evenly between the two legs.

    ``direction`` is "long" or "short" and is matched case-insensitively.
    Any other value raises ``ValueError``. Previously anything other than
    the exact string "long" (e.g. "Long") silently produced a short position.
    """
    side = str(direction).strip().lower()
    if side not in ("long", "short"):
        raise ValueError(f"direction must be 'long' or 'short', got {direction!r}")
    qty = 1 if side == "long" else -1

    s = Strategy(f"{'Long' if qty > 0 else 'Short'} Straddle {strike}")
    s.add_leg("call", strike, total_premium / 2, qty)
    s.add_leg("put", strike, total_premium / 2, qty)
    return s


def strangle(put_K: float, call_K: float, total_premium: float,
             direction: str = "long") -> Strategy:
    """Build a long or short Strangle: a put at ``put_K`` and a call at ``call_K``.

    ``total_premium`` is split evenly between the two legs.

    ``direction`` is "long" or "short" and is matched case-insensitively.
    Any other value raises ``ValueError``. Previously anything other than
    the exact string "long" (e.g. "Long") silently produced a short position.
    """
    side = str(direction).strip().lower()
    if side not in ("long", "short"):
        raise ValueError(f"direction must be 'long' or 'short', got {direction!r}")
    qty = 1 if side == "long" else -1

    s = Strategy(f"{'Long' if qty > 0 else 'Short'} Strangle {put_K}/{call_K}")
    s.add_leg("put", put_K, total_premium / 2, qty)
    s.add_leg("call", call_K, total_premium / 2, qty)
    return s


def covered_call(stock_price: float, call_strike: float,
                 call_premium: float) -> Strategy:
    """
    Build a Covered Call: long stock plus a short call.

    Note: the underlying is modelled as a call with strike 0 bought at
    ``stock_price``. For prices S >= 0 its payoff ``max(S, 0) - stock_price``
    equals the stock's ``S - stock_price``.
    """
    title = f"Covered Call (stock@{stock_price}, call@{call_strike})"
    position = Strategy(title)
    stock_leg = ("call", 0, stock_price, 1)                  # synthetic long stock
    short_call_leg = ("call", call_strike, call_premium, -1)  # sold call
    for leg_args in (stock_leg, short_call_leg):
        position.add_leg(*leg_args)
    return position


# =============================================================================
# ESEMPIO D'USO
# =============================================================================

if __name__ == "__main__":
    separator = "=" * 60
    for header_line in (separator, "PAYOFF DIAGRAMS - Esempi", separator):
        print(header_line)

    # Example 1: Iron Condor on SPX, with statistics printed for a fixed grid.
    print("\n1. Iron Condor SPX 5400/5450/5550/5600, credito $4.50")
    condor = iron_condor(5400, 5450, 5550, 5600, 4.50)
    grid = np.linspace(5350, 5650, 1000)
    print(f"   Max Profit: ${condor.max_profit(grid):.2f}")
    print(f"   Max Loss: ${condor.max_loss(grid):.2f}")
    be_labels = [f"${b:.1f}" for b in condor.breakevens(grid)]
    print(f"   Breakeven: {be_labels}")
    condor.plot()

    # Example 2: Bull Put Spread
    print("\n2. Bull Put Spread 5400/5350, credito $2.00")
    bull_put = bull_put_spread(5400, 5350, 2.00)
    bull_put.plot(S_range=(5300, 5500))

    # Example 3: Short Strangle
    print("\n3. Short Strangle 5400/5600, premio totale $8.00")
    short_strangle = strangle(5400, 5600, 8.00, direction="short")
    short_strangle.plot(S_range=(5300, 5700))

    # Example 4: custom strategy built by chaining add_leg.
    # NOTE(review): the strikes are equally spaced (50/50), so as written this
    # is a regular put butterfly; widen one wing for a true broken wing.
    print("\n4. Strategia custom: Broken Wing Butterfly")
    broken_wing = (
        Strategy("Broken Wing Butterfly Put")
        .add_leg("put", 5400, 1.0, 1)    # long 1 put
        .add_leg("put", 5450, 3.5, -2)   # short 2 puts
        .add_leg("put", 5500, 6.0, 1)    # long 1 put (the "wide" wing)
    )
    broken_wing.plot(S_range=(5350, 5550))
