diff --git a/app/api/cointegration.py b/app/api/cointegration.py index f07cd469..6664f8a0 100644 --- a/app/api/cointegration.py +++ b/app/api/cointegration.py @@ -30,7 +30,7 @@ async def cointegration( fmp: FMPDep, redis: RedisDep, frequency: FMP.freq = Query( - FMP.freq.daily, + FMP.freq.DAILY, description="Price sampling frequency.", ), ) -> CointegrationResponse: diff --git a/docs/examples/heston_volatility_pricer.py b/docs/examples/heston_volatility_pricer.py index c6c36c08..3af03754 100644 --- a/docs/examples/heston_volatility_pricer.py +++ b/docs/examples/heston_volatility_pricer.py @@ -17,7 +17,7 @@ # Price an ATM call option at time to maturity 1.0 price = pricer.price( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=100.0, forward=100.0, ttm=1.0, diff --git a/docs/examples/wiener_volatility_pricer.py b/docs/examples/wiener_volatility_pricer.py index 472ad0ca..757f8676 100644 --- a/docs/examples/wiener_volatility_pricer.py +++ b/docs/examples/wiener_volatility_pricer.py @@ -8,7 +8,7 @@ # Price an ATM call option at time to maturity 1.0 price = pricer.price( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=100.0, forward=100.0, ttm=1.0, diff --git a/quantflow/data/deribit.py b/quantflow/data/deribit.py index da8a0b5c..fef57c4c 100644 --- a/quantflow/data/deribit.py +++ b/quantflow/data/deribit.py @@ -31,11 +31,11 @@ def parse_maturity(v: str) -> datetime: class InstrumentKind(enum.StrEnum): """Instrument kind for Deribit API.""" - future = enum.auto() - option = enum.auto() - spot = enum.auto() - future_combo = enum.auto() - option_combo = enum.auto() + FUTURE = enum.auto() + OPTION = enum.auto() + SPOT = enum.auto() + FUTURE_COMBO = enum.auto() + OPTION_COMBO = enum.auto() @dataclass @@ -158,18 +158,18 @@ async def volatility_surface_loader( ) if inverse: futures = await self.get_book_summary_by_currency( - currency=currency, kind=InstrumentKind.future + currency=currency, kind=InstrumentKind.FUTURE ) options = await self.get_book_summary_by_currency( - currency=currency, kind=InstrumentKind.option + currency=currency, kind=InstrumentKind.OPTION ) instruments = await self.get_instruments(currency=currency) else: futures = await self.get_book_summary_by_currency( - currency="usdc", kind=InstrumentKind.future, base=currency + currency="usdc", kind=InstrumentKind.FUTURE, base=currency ) options = await self.get_book_summary_by_currency( - currency="usdc", kind=InstrumentKind.option, base=currency + currency="usdc", kind=InstrumentKind.OPTION, base=currency ) instruments = await self.get_instruments(currency="usdc", base=currency) instrument_map = {i["instrument_name"]: i for i in instruments} @@ -226,9 +226,9 @@ async def volatility_surface_loader( utc=True, ).to_pydatetime(), option_type=( - OptionType.call + OptionType.CALL if meta["option_type"] == "call" - else OptionType.put + else OptionType.PUT ), bid=round_to_step(bid_, tick_size), ask=round_to_step(ask_, tick_size), diff --git a/quantflow/data/fmp.py b/quantflow/data/fmp.py index 4b62a901..406f0394 100644 --- a/quantflow/data/fmp.py +++ b/quantflow/data/fmp.py @@ -25,22 +25,22 @@ class FMP(AioHttpClient): class freq(StrEnum): """FMP historical frequencies""" - one_min = "1min" - five_min = "5min" - fifteen_min = "15min" - thirty_min = "30min" - one_hour = "1hour" - four_hour = "4hour" - daily = "daily" + ONE_MIN = "1min" + FIVE_MIN = "5min" + FIFTEEN_MIN = "15min" + THIRTY_MIN = "30min" + ONE_HOUR = "1hour" + FOUR_HOUR = "4hour" + DAILY = "daily" @classmethod def crate(cls, s: str | None) -> Self: if s is None: - return cls.daily + return cls.DAILY try: return cls(s) except ValueError: - return cls.daily + return cls.DAILY async def market_risk_premium(self) -> list[dict]: """Market risk premium""" @@ -204,7 +204,7 @@ async def prices( freq = self.freq.crate(frequency) path = ( "historical-price-eod/full" - if freq is self.freq.daily + if freq is self.freq.DAILY else f"historical-chart/{freq}" ) data = await self.get_path( diff --git a/quantflow/data/fred.py b/quantflow/data/fred.py index c4313fc6..56a5713a 100644 --- a/quantflow/data/fred.py +++ b/quantflow/data/fred.py @@ -22,13 +22,13 @@ class Fred(AioHttpClient): class freq(StrEnum): """Fred historical frequencies""" - d = "d" - w = "w" - bw = "bw" - m = "m" - q = "q" - sa = "sa" - a = "a" + D = "d" + W = "w" + BW = "bw" + M = "m" + Q = "q" + SA = "sa" + A = "a" async def categiories(self, **kw: Any) -> dict: """Get categories""" diff --git a/quantflow/data/yahoo.py b/quantflow/data/yahoo.py index cc66b37a..cdd89e9d 100644 --- a/quantflow/data/yahoo.py +++ b/quantflow/data/yahoo.py @@ -62,17 +62,17 @@ class Yahoo(HttpxClient): class freq(StrEnum): """Yahoo Finance chart intervals""" - one_min = "1m" - two_min = "2m" - five_min = "5m" - fifteen_min = "15m" - thirty_min = "30m" - one_hour = "1h" - one_day = "1d" - five_day = "5d" - one_week = "1wk" - one_month = "1mo" - three_month = "3mo" + ONE_MIN = "1m" + TWO_MIN = "2m" + FIVE_MIN = "5m" + FIFTEEN_MIN = "15m" + THIRTY_MIN = "30m" + ONE_HOUR = "1h" + ONE_DAY = "1d" + FIVE_DAY = "5d" + ONE_WEEK = "1wk" + ONE_MONTH = "1mo" + THREE_MONTH = "3mo" async def option_chain( self, @@ -160,8 +160,8 @@ def loader_from_chain( .replace(hour=20, tzinfo=timezone.utc) ) for option_type, contracts in ( - (OptionType.call, expiry.get("calls", [])), - (OptionType.put, expiry.get("puts", [])), + (OptionType.CALL, expiry.get("calls", [])), + (OptionType.PUT, expiry.get("puts", [])), ): for c in contracts: bid_ = c.get("bid") @@ -187,7 +187,7 @@ async def prices( *, interval: Annotated[ str | freq, Doc("Bar interval — use Yahoo.freq members or a raw string") - ] = freq.one_day, + ] = freq.ONE_DAY, from_date: Annotated[date | None, Doc("Start date (inclusive)")] = None, to_date: Annotated[date | None, Doc("End date (inclusive)")] = None, range: Annotated[ diff --git a/quantflow/options/inputs.py b/quantflow/options/inputs.py index dc4e1709..1d1ccf5a 100644 --- a/quantflow/options/inputs.py +++ b/quantflow/options/inputs.py @@ -18,25 +18,25 @@ class Side(enum.StrEnum): """Side of the market""" - bid = enum.auto() - ask = enum.auto() + BID = enum.auto() + ASK = enum.auto() class OptionType(enum.StrEnum): """Type of option""" - call = enum.auto() - put = enum.auto() + CALL = enum.auto() + PUT = enum.auto() def is_call(self) -> bool: - return self is OptionType.call + return self is OptionType.CALL def is_put(self) -> bool: - return self is OptionType.put + return self is OptionType.PUT def call_put(self) -> int: """Return 1 for call options and -1 for put options""" - return 1 if self is OptionType.call else -1 + return 1 if self is OptionType.CALL else -1 class OptionMetadata(BaseModel): @@ -65,9 +65,9 @@ def is_in_the_money(self, forward: Decimal) -> bool: class VolSecurityType(enum.StrEnum): """Type of security for the volatility surface""" - spot = enum.auto() - forward = enum.auto() - option = enum.auto() + SPOT = enum.auto() + FORWARD = enum.auto() + OPTION = enum.auto() class VolSurfaceSecurity(BaseModel): @@ -84,7 +84,7 @@ def forward(cls) -> Self: class DefaultVolSecurity(VolSurfaceSecurity): security_type: VolSecurityType = Field( - default=VolSecurityType.spot, + default=VolSecurityType.SPOT, description="Type of security for the volatility surface", ) @@ -93,22 +93,22 @@ def vol_surface_type(self) -> VolSecurityType: @classmethod def spot(cls) -> Self: - return cls(security_type=VolSecurityType.spot) + return cls(security_type=VolSecurityType.SPOT) @classmethod def forward(cls) -> Self: - return cls(security_type=VolSecurityType.forward) + return cls(security_type=VolSecurityType.FORWARD) @classmethod def option(cls) -> Self: - return cls(security_type=VolSecurityType.option) + return cls(security_type=VolSecurityType.OPTION) class SpotInput(PriceVolume): """Input data for a spot contract in the volatility surface""" security_type: VolSecurityType = Field( - default=VolSecurityType.spot, + default=VolSecurityType.SPOT, description="Type of security for the volatility surface", ) @@ -118,7 +118,7 @@ class ForwardInput(PriceVolume): maturity: datetime = Field(description="Expiry date of the forward contract") security_type: VolSecurityType = Field( - default=VolSecurityType.forward, + default=VolSecurityType.FORWARD, description="Type of security for the volatility surface", ) @@ -127,7 +127,7 @@ class OptionInput(PriceVolume, OptionMetadata): """Input data for an option in the volatility surface""" security_type: VolSecurityType = Field( - default=VolSecurityType.option, + default=VolSecurityType.OPTION, description="Type of security for the volatility surface", ) iv_bid: DecimalNumber | None = Field( diff --git a/quantflow/options/pricer.py b/quantflow/options/pricer.py index 4c5863c5..003e2ef6 100644 --- a/quantflow/options/pricer.py +++ b/quantflow/options/pricer.py @@ -103,7 +103,7 @@ def intrinsic_value(self) -> float: For a put option, the intrinsic value is non-negative when the moneyness is positive, i.e. when the strike is above the forward price. """ - if self.option_type == OptionType.call: + if self.option_type == OptionType.CALL: return max(0.0, self.parity) else: return max(0.0, -self.parity) @@ -118,7 +118,7 @@ def as_option_type( """Convert the option price to the given option type via put-call parity.""" if self.option_type == option_type: return self - if self.option_type == OptionType.call: + if self.option_type == OptionType.CALL: new_price = self.price - self.parity new_delta = self.delta - 1.0 else: @@ -161,7 +161,7 @@ def price( log_strike = float((strike_ / forward_).ln()) result = self.pricing.call_greeks(log_strike) return ModelOptionPrice( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=strike_, forward=forward_, log_strike=log_strike, diff --git a/quantflow/options/strategies/butterfly.py b/quantflow/options/strategies/butterfly.py index 5b7bfb1d..0b7679ec 100644 --- a/quantflow/options/strategies/butterfly.py +++ b/quantflow/options/strategies/butterfly.py @@ -17,7 +17,7 @@ def _option_type_for_log_strike(mid_log_strike: float) -> OptionType: Calls for body above ATM, puts for body below ATM, calls at ATM. """ - return OptionType.put if mid_log_strike < 0 else OptionType.call + return OptionType.PUT if mid_log_strike < 0 else OptionType.CALL class Butterfly(Strategy, frozen=True): diff --git a/quantflow/options/strategies/calendar_spread.py b/quantflow/options/strategies/calendar_spread.py index 9af0146c..2a571f76 100644 --- a/quantflow/options/strategies/calendar_spread.py +++ b/quantflow/options/strategies/calendar_spread.py @@ -68,7 +68,7 @@ def call( quantity: Number = 1.0, ) -> Self: return cls.create( - strike, near_maturity, far_maturity, OptionType.call, quantity + strike, near_maturity, far_maturity, OptionType.CALL, quantity ) @classmethod @@ -79,4 +79,4 @@ def put( far_maturity: datetime, quantity: Number = 1.0, ) -> Self: - return cls.create(strike, near_maturity, far_maturity, OptionType.put, quantity) + return cls.create(strike, near_maturity, far_maturity, OptionType.PUT, quantity) diff --git a/quantflow/options/strategies/spread.py b/quantflow/options/strategies/spread.py index 50c43ab9..26535324 100644 --- a/quantflow/options/strategies/spread.py +++ b/quantflow/options/strategies/spread.py @@ -39,7 +39,7 @@ def call( legs=( StrategyLeg( meta=OptionMetadata( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=low, maturity=maturity, ), @@ -47,7 +47,7 @@ def call( ), StrategyLeg( meta=OptionMetadata( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=high, maturity=maturity, ), @@ -74,7 +74,7 @@ def put( legs=( StrategyLeg( meta=OptionMetadata( - option_type=OptionType.put, + option_type=OptionType.PUT, strike=high, maturity=maturity, ), @@ -82,7 +82,7 @@ def put( ), StrategyLeg( meta=OptionMetadata( - option_type=OptionType.put, + option_type=OptionType.PUT, strike=low, maturity=maturity, ), diff --git a/quantflow/options/strategies/straddle.py b/quantflow/options/strategies/straddle.py index d47e0902..1c3d94e0 100644 --- a/quantflow/options/strategies/straddle.py +++ b/quantflow/options/strategies/straddle.py @@ -28,7 +28,7 @@ def create(cls, strike: Number, maturity: datetime, quantity: Number = 1.0) -> S legs=( StrategyLeg( meta=OptionMetadata( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=strike_, maturity=maturity, ), @@ -36,7 +36,7 @@ def create(cls, strike: Number, maturity: datetime, quantity: Number = 1.0) -> S ), StrategyLeg( meta=OptionMetadata( - option_type=OptionType.put, + option_type=OptionType.PUT, strike=strike_, maturity=maturity, ), diff --git a/quantflow/options/strategies/strangle.py b/quantflow/options/strategies/strangle.py index 6f0d97aa..23f18a0f 100644 --- a/quantflow/options/strategies/strangle.py +++ b/quantflow/options/strategies/strangle.py @@ -37,7 +37,7 @@ def from_strikes( legs=( StrategyLeg( meta=OptionMetadata( - option_type=OptionType.put, + option_type=OptionType.PUT, strike=put, maturity=maturity, ), @@ -45,7 +45,7 @@ def from_strikes( ), StrategyLeg( meta=OptionMetadata( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=call, maturity=maturity, ), diff --git a/quantflow/options/surface.py b/quantflow/options/surface.py index 3fae2942..12b0043d 100644 --- a/quantflow/options/surface.py +++ b/quantflow/options/surface.py @@ -68,18 +68,18 @@ class OptionSelection(enum.Enum): for calculating implied volatility and other operations """ - best = enum.auto() + BEST = enum.auto() """Select the OTM option but blend call and put implied volatilities near the money. The blending weight transitions linearly from 50/50 at moneyness 0 to pure OTM at the moneyness threshold.""" - otm = enum.auto() + OTM = enum.auto() """Select Out of the Money options only, where their intrinsic value is zero""" - call = enum.auto() + CALL = enum.auto() """Select the call options only""" - put = enum.auto() + PUT = enum.auto() """Select the put options only""" - all = enum.auto() + ALL = enum.auto() """Select all options regardless of their moneyness""" @@ -146,7 +146,7 @@ class OptionPrice(BaseModel): ttm: float = Field(default=0, description="Time to maturity in years") iv: float = Field(default=0, description="Implied volatility of the option") side: Side = Field( - default=Side.bid, description="Side of the market for the option price" + default=Side.BID, description="Side of the market for the option price" ) converged: bool = Field( default=False, @@ -464,7 +464,7 @@ def options_iter( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, ) -> Iterator[OptionPrices[S]]: """Iterator over option prices for the strike @@ -474,23 +474,23 @@ def options_iter( case only the Out of the Money options are included in the iteration. """ match select: - case OptionSelection.otm: + case OptionSelection.OTM: if self.call and not self.call.is_in_the_money(forward): yield self.call elif self.put and not self.put.is_in_the_money(forward): yield self.put - case OptionSelection.best: + case OptionSelection.BEST: if self.call and not self.call.is_in_the_money(forward): yield self.call elif self.put and not self.put.is_in_the_money(forward): yield self.put - case OptionSelection.call: + case OptionSelection.CALL: if self.call: yield self.call - case OptionSelection.put: + case OptionSelection.PUT: if self.put: yield self.put - case OptionSelection.all: + case OptionSelection.ALL: if self.call: yield self.call if self.put: @@ -503,7 +503,7 @@ def securities( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, converged: Annotated[ bool, Doc( @@ -524,7 +524,7 @@ def option_prices( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, initial_vol: Annotated[ float, Doc("Initial volatility for the root finding algorithm") ] = INITIAL_VOL, @@ -599,7 +599,7 @@ def option_prices( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, initial_vol: Annotated[ float, Doc("Initial volatility for the root finding algorithm") ] = INITIAL_VOL, @@ -622,7 +622,7 @@ def securities( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, converged: Annotated[ bool, Doc( @@ -641,7 +641,7 @@ def option_securities( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, converged: Annotated[ bool, Doc( @@ -839,7 +839,7 @@ def securities( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -868,7 +868,7 @@ def inputs( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.all, + ] = OptionSelection.ALL, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -914,7 +914,7 @@ def option_prices( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -955,7 +955,7 @@ def option_list( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -976,7 +976,7 @@ def bs( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -1015,7 +1015,7 @@ def calc_bs_prices( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -1033,7 +1033,7 @@ def options_df( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -1058,7 +1058,7 @@ def as_array( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -1110,7 +1110,7 @@ def as_array( def reset_convergence(self) -> None: """Reset the convergence flag for all options in the surface""" - for option in self.option_prices(select=OptionSelection.all): + for option in self.option_prices(select=OptionSelection.ALL): option.converged = False def disable_outliers( @@ -1160,7 +1160,7 @@ def plot( ] = None, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, **kwargs: Any, ) -> Any: """Plot the volatility surface""" @@ -1172,7 +1172,7 @@ def plot3d( *, select: Annotated[ OptionSelection, Doc("Option selection method") - ] = OptionSelection.best, + ] = OptionSelection.BEST, index: Annotated[ int | None, Doc("Index of the cross section to use, if None use all") ] = None, @@ -1230,8 +1230,8 @@ def add_option( option = OptionPrices( security=security, meta=meta, - bid=OptionPrice(price=normalize_decimal(bid), meta=meta, side=Side.bid), - ask=OptionPrice(price=normalize_decimal(ask), meta=meta, side=Side.ask), + bid=OptionPrice(price=normalize_decimal(bid), meta=meta, side=Side.BID), + ask=OptionPrice(price=normalize_decimal(ask), meta=meta, side=Side.ASK), open_interest=normalize_decimal(open_interest), volume=normalize_decimal(volume), ) @@ -1332,7 +1332,7 @@ def add_spot( volume: Annotated[Decimal, Doc("Volume for the spot")] = ZERO, ) -> None: """Add a spot to the volatility surface loader""" - if security.vol_surface_type() != VolSecurityType.spot: + if security.vol_surface_type() != VolSecurityType.SPOT: raise ValueError("Security is not a spot") self.spot = SpotPrice( security=security, @@ -1352,7 +1352,7 @@ def add_forward( volume: Annotated[Decimal, Doc("Volume for the forward")] = ZERO, ) -> None: """Add a forward to the volatility surface loader""" - if security.vol_surface_type() != VolSecurityType.forward: + if security.vol_surface_type() != VolSecurityType.FORWARD: raise ValueError("Security is not a forward") self.get_or_create_maturity(maturity=maturity).forward = FwdPrice( security=security, @@ -1376,7 +1376,7 @@ def add_option( inverse: Annotated[bool, Doc("Whether the option is an inverse option")] = True, ) -> None: """Add an option to the volatility surface loader""" - if security.vol_surface_type() != VolSecurityType.option: + if security.vol_surface_type() != VolSecurityType.OPTION: raise ValueError("Security is not an option") if self.exclude_volume is not None and volume <= self.exclude_volume: return diff --git a/quantflow/rates/__init__.py b/quantflow/rates/__init__.py index 2d262134..fb33a193 100644 --- a/quantflow/rates/__init__.py +++ b/quantflow/rates/__init__.py @@ -5,6 +5,10 @@ from .calibration import YieldCurveCalibration from .cir import CIRCurve from .interest_rate import Rate +from .interpolated import ( + InterpolatedYieldCurve, + InterpolationType, +) from .nelson_siegel import NelsonSiegel from .no_discount import NoDiscount from .vasicek import VasicekCurve @@ -15,6 +19,8 @@ "YieldCurveCalibration", "NoDiscount", "CIRCurve", + "InterpolatedYieldCurve", + "InterpolationType", "NelsonSiegel", "VasicekCurve", "AnyYieldCurve", @@ -22,8 +28,10 @@ ] AnyYieldCurve = Annotated[ - Union[NoDiscount, CIRCurve, NelsonSiegel, VasicekCurve], + Union[NoDiscount, CIRCurve, InterpolatedYieldCurve, NelsonSiegel, VasicekCurve], Field(discriminator="curve_type"), ] -YieldCurve.register_curve_types(NoDiscount, CIRCurve, NelsonSiegel, VasicekCurve) +YieldCurve.register_curve_types( + NoDiscount, CIRCurve, InterpolatedYieldCurve, NelsonSiegel, VasicekCurve +) diff --git a/quantflow/rates/interpolated.py b/quantflow/rates/interpolated.py new file mode 100644 index 00000000..137138ae --- /dev/null +++ b/quantflow/rates/interpolated.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import enum +from datetime import datetime, timedelta +from decimal import Decimal +from typing import Literal + +import numpy as np +from numpy.typing import ArrayLike +from pydantic import Field, PrivateAttr +from scipy.interpolate import PchipInterpolator +from scipy.optimize import Bounds +from typing_extensions import Annotated, Doc + +from quantflow.utils.dates import as_utc +from quantflow.utils.numbers import DecimalNumber +from quantflow.utils.types import FloatArray, FloatArrayLike, maybe_float + +from .calibration import YieldCurveCalibration +from .yield_curve import YieldCurve + +_YEAR = 365.0 * 86400.0 + + +class InterpolationType(enum.StrEnum): + """Interpolation method for the log discount factor""" + + LINEAR = enum.auto() + """Piecewise linear in the log discount factor (piecewise constant forwards).""" + + MONOTONE_CUBIC = enum.auto() + """Shape-preserving cubic Hermite spline (PCHIP, Fritsch-Carlson) that never + introduces a new local maximum or minimum between two nodes.""" + + +class InterpolatedYieldCurve(YieldCurve, arbitrary_types_allowed=True): + r"""Yield curve built by interpolating the log discount factor on a set of nodes. + + The curve is defined by continuously compounded zero rates $r_i$ at a set of + anchor dates. Times to maturity $\tau_i$ are measured from + [ref_date][..ref_date] on an ACT/365 basis, and the log discount factor at + each node is $g_i = -r_i \tau_i$. The curve interpolates $g(\tau) = \ln + D(\tau)$ between the nodes, which keeps the instantaneous forward rate + $f(\tau) = -g'(\tau)$ simple: piecewise constant for linear interpolation and + smooth for cubic interpolation. + + The node $\tau = 0$ with $g = 0$ (i.e. $D(0) = 1$) is added automatically. + Beyond the last node the instantaneous forward rate is held flat (constant + forward extrapolation). + """ + + curve_type: Literal["interpolated_yield_curve"] = "interpolated_yield_curve" + anchor_dates: list[datetime] = Field( + description="Maturity dates of the interpolation nodes, strictly after the " + "reference date and in increasing order" + ) + anchor_rates: list[DecimalNumber] = Field( + description="Continuously compounded zero rates at each anchor date " + "(0.05 means 5%), same length as anchor_dates" + ) + interpolation_type: InterpolationType = Field( + default=InterpolationType.LINEAR, + description="Interpolation method for the log discount factor: " + "LINEAR or MONOTONE_CUBIC (PCHIP)", + ) + + _ttm: FloatArray = PrivateAttr(default_factory=lambda: np.empty(0)) + _log_discount: FloatArray = PrivateAttr(default_factory=lambda: np.empty(0)) + + def model_post_init(self, context: object) -> None: + """Cache the times to maturity and log discount factors at the nodes.""" + if len(self.anchor_dates) != len(self.anchor_rates): + raise ValueError("anchor_dates and anchor_rates must have equal length") + if not self.anchor_dates: + raise ValueError("at least one anchor is required") + ttm = self._year_fractions() + if np.any(ttm <= 0): + raise ValueError("anchor_dates must be strictly after the reference date") + if np.any(np.diff(ttm) <= 0): + raise ValueError("anchor_dates must be strictly increasing") + rates = np.array([float(r) for r in self.anchor_rates], dtype=float) + self._ttm = ttm + self._log_discount = -rates * ttm + + def _year_fractions(self) -> FloatArray: + """Times to maturity in years from ref_date, ACT/365.""" + ref = as_utc(self.ref_date) + return np.array( + [(as_utc(m) - ref).total_seconds() / _YEAR for m in self.anchor_dates], + dtype=float, + ) + + def calibrator(self) -> InterpolatedYieldCurveCalibration: + """Return an [InterpolatedYieldCurveCalibration][ + ...InterpolatedYieldCurveCalibration] wrapping this curve.""" + return InterpolatedYieldCurveCalibration(yield_curve=self) + + def _nodes(self) -> tuple[FloatArray, FloatArray]: + """Node times and log discount factors, with the origin pinned to (0, 0).""" + t = np.concatenate([[0.0], self._ttm]) + g = np.concatenate([[0.0], self._log_discount]) + return t, g + + def instantaneous_forward_rate(self, ttm: FloatArrayLike) -> FloatArrayLike: + t, g = self._nodes() + tau = np.maximum(np.asarray(ttm, dtype=float), 0.0) + tmax = t[-1] + if self.interpolation_type is InterpolationType.LINEAR: + slope = np.diff(g) / np.diff(t) + idx = np.clip( + np.searchsorted(t, np.minimum(tau, tmax), side="right") - 1, + 0, + slope.size - 1, + ) + f = -slope[idx] + else: + dg = PchipInterpolator(t, g).derivative() + f = -dg(np.minimum(tau, tmax)) + return maybe_float(f) + + def discount_factor(self, ttm: FloatArrayLike) -> FloatArrayLike: + t, g = self._nodes() + tau = np.maximum(np.asarray(ttm, dtype=float), 0.0) + tmax = t[-1] + inside = np.minimum(tau, tmax) + if self.interpolation_type is InterpolationType.LINEAR: + slope = np.diff(g) / np.diff(t) + gi = np.interp(inside, t, g) + f_last = -slope[-1] + else: + pch = PchipInterpolator(t, g) + gi = pch(inside) + f_last = -float(pch.derivative()(tmax)) + # constant forward rate extrapolation beyond the last node + gi = np.where(tau > tmax, g[-1] - f_last * (tau - tmax), gi) + return maybe_float(np.exp(gi)) + + +class InterpolatedYieldCurveCalibration(YieldCurveCalibration[InterpolatedYieldCurve]): + """Calibration wrapper for an interpolated yield curve. + + The interpolated curve passes exactly through its nodes, so calibration is + direct: the anchor dates and rates are set from the input times to maturity + and continuously compounded rates. The free parameters are the anchor rates. + """ + + def get_params(self) -> FloatArray: + return np.array([float(r) for r in self.yield_curve.anchor_rates], dtype=float) + + def set_params(self, params: FloatArray) -> None: + curve = self.yield_curve + rates = np.asarray(params, dtype=float) + curve.anchor_rates = [Decimal(str(round(float(r), 10))) for r in rates] + curve._log_discount = -rates * curve._ttm + + def get_bounds(self) -> Bounds: + n = len(self.yield_curve.anchor_rates) + return Bounds(np.full(n, -np.inf), np.full(n, np.inf)) + + def calibrate( + self, + ttm: Annotated[ArrayLike, Doc("Times to maturity in years.")], + rates: Annotated[ + ArrayLike, Doc("Continuously compounded rates, same length as ttm.") + ], + ) -> InterpolatedYieldCurve: + """Set the curve nodes so it reprices the given rates exactly. + + Maturity dates are reconstructed from the times to maturity relative to + [ref_date][..ref_date] on an ACT/365 basis. + """ + ttm_ = np.asarray(ttm, dtype=float) + rates_ = np.asarray(rates, dtype=float) + order = np.argsort(ttm_) + curve = self.yield_curve + ref = curve.ref_date + curve.anchor_dates = [ + ref + timedelta(seconds=float(t) * _YEAR) for t in ttm_[order] + ] + curve.anchor_rates = [Decimal(str(round(float(r), 10))) for r in rates_[order]] + curve._ttm = ttm_[order] + curve._log_discount = -rates_[order] * ttm_[order] + return curve diff --git a/quantflow/utils/types.py b/quantflow/utils/types.py index 076a1932..c8da9fb7 100644 --- a/quantflow/utils/types.py +++ b/quantflow/utils/types.py @@ -4,7 +4,7 @@ import numpy as np import numpy.typing as npt import pandas as pd -from pydantic import PlainSerializer +from pydantic import BeforeValidator, PlainSerializer from typing_extensions import TypeAlias Number = Decimal @@ -14,6 +14,7 @@ Vector: TypeAlias = int | float | complex | np.ndarray | pd.Series FloatArray = Annotated[ npt.NDArray[np.floating[Any]], + BeforeValidator(lambda x: np.asarray(x, dtype=float)), PlainSerializer(lambda x: x.tolist(), return_type=list), ] IntArray = npt.NDArray[np.signedinteger[Any]] diff --git a/quantflow_tests/test_data.py b/quantflow_tests/test_data.py index 455236f8..952a6caa 100644 --- a/quantflow_tests/test_data.py +++ b/quantflow_tests/test_data.py @@ -27,7 +27,7 @@ def test_client(fmp: FMP) -> None: @pytest.mark.skipif(skip_fmp, reason="No FMP API key found") async def test_historical(fmp: FMP) -> None: - df = await fmp.prices("BTCUSD", frequency=fmp.freq.one_hour) + df = await fmp.prices("BTCUSD", frequency=fmp.freq.ONE_HOUR) assert df["close"] is not None diff --git a/quantflow_tests/test_disable_outliers.py b/quantflow_tests/test_disable_outliers.py index 726dac08..5fe6f169 100644 --- a/quantflow_tests/test_disable_outliers.py +++ b/quantflow_tests/test_disable_outliers.py @@ -31,7 +31,7 @@ def _make_option( strike: float, iv_mid: float, iv_spread_fraction: float, - option_type: OptionType = OptionType.call, + option_type: OptionType = OptionType.CALL, ) -> OptionPrices[DefaultVolSecurity]: iv_bid = iv_mid * (1 - iv_spread_fraction / 2) iv_ask = iv_mid * (1 + iv_spread_fraction / 2) @@ -47,7 +47,7 @@ def _make_option( forward=FORWARD, ttm=TTM, iv=iv_bid, - side=Side.bid, + side=Side.BID, converged=True, ) ask = OptionPrice( @@ -56,7 +56,7 @@ def _make_option( forward=FORWARD, ttm=TTM, iv=iv_ask, - side=Side.ask, + side=Side.ASK, converged=True, ) return OptionPrices(security=SECURITY, meta=meta, bid=bid, ask=ask) diff --git a/quantflow_tests/test_fmp_unit.py b/quantflow_tests/test_fmp_unit.py index 4f0200d6..da80064e 100644 --- a/quantflow_tests/test_fmp_unit.py +++ b/quantflow_tests/test_fmp_unit.py @@ -9,9 +9,9 @@ def test_freq_crate_and_join_and_params() -> None: - assert FMP.freq.crate(None) == FMP.freq.daily - assert FMP.freq.crate("1hour") == FMP.freq.one_hour - assert FMP.freq.crate("bad") == FMP.freq.daily + assert FMP.freq.crate(None) == FMP.freq.DAILY + assert FMP.freq.crate("1hour") == FMP.freq.ONE_HOUR + assert FMP.freq.crate("bad") == FMP.freq.DAILY fmp = FMP(key="k") assert fmp.join("AAPL", "MSFT") == "AAPL,MSFT" diff --git a/quantflow_tests/test_interpolated_curve.py b/quantflow_tests/test_interpolated_curve.py new file mode 100644 index 00000000..cb197978 --- /dev/null +++ b/quantflow_tests/test_interpolated_curve.py @@ -0,0 +1,227 @@ +"""Tests for the interpolated yield curve (log discount factor interpolation).""" + +from __future__ import annotations + +import math +from datetime import datetime, timedelta, timezone +from decimal import Decimal + +import numpy as np +import pytest +from pydantic import TypeAdapter, ValidationError + +from quantflow.rates import AnyYieldCurve, InterpolatedYieldCurve, InterpolationType + +_REF = datetime(2026, 6, 7, tzinfo=timezone.utc) +_TTM = np.array([0.25, 1.0, 2.0, 5.0, 10.0]) +_RATES = [Decimal(r) for r in ("0.02", "0.025", "0.03", "0.035", "0.04")] +_RATES_F = np.array([float(r) for r in _RATES]) +_YEAR = 365.0 * 86400.0 +_ADAPTER: TypeAdapter[AnyYieldCurve] = TypeAdapter(AnyYieldCurve) + + +def _dates(ttm: np.ndarray = _TTM) -> list[datetime]: + return [_REF + timedelta(seconds=float(t) * _YEAR) for t in ttm] + + +def _curve( + interpolation_type: InterpolationType = InterpolationType.MONOTONE_CUBIC, + rates: list[Decimal] = _RATES, +) -> InterpolatedYieldCurve: + return InterpolatedYieldCurve( + ref_date=_REF, + anchor_dates=_dates(), + anchor_rates=rates, + interpolation_type=interpolation_type, + ) + + +@pytest.fixture(params=list(InterpolationType)) +def interpolation_type(request: pytest.FixtureRequest) -> InterpolationType: + return request.param + + +# --------------------------------------------------------------------------- +# Construction and derived state +# --------------------------------------------------------------------------- + + +def test_default_interpolation_is_linear() -> None: + curve = InterpolatedYieldCurve( + ref_date=_REF, anchor_dates=_dates(), anchor_rates=_RATES + ) + assert curve.interpolation_type is InterpolationType.LINEAR + + +def test_anchor_rates_coerced_to_decimal() -> None: + curve = _curve() + assert all(isinstance(r, Decimal) for r in curve.anchor_rates) + + +def test_private_attrs_populated() -> None: + curve = _curve() + assert curve._ttm == pytest.approx(_TTM) + assert curve._log_discount == pytest.approx(-_RATES_F * _TTM) + + +# --------------------------------------------------------------------------- +# Discount factor and rates +# --------------------------------------------------------------------------- + + +def test_discount_factor_at_zero_is_one(interpolation_type: InterpolationType) -> None: + curve = _curve(interpolation_type) + assert float(curve.discount_factor(0.0)) == pytest.approx(1.0) + + +def test_reprices_nodes_exactly(interpolation_type: InterpolationType) -> None: + curve = _curve(interpolation_type) + fitted = curve.continuously_compounded_rate(_TTM) + assert np.asarray(fitted) == pytest.approx(_RATES_F) + + +def test_discount_factor_matches_node_rates( + interpolation_type: InterpolationType, +) -> None: + curve = _curve(interpolation_type) + for t, r in zip(_TTM, _RATES_F): + assert float(curve.discount_factor(t)) == pytest.approx(math.exp(-r * t)) + + +def test_scalar_input_returns_float(interpolation_type: InterpolationType) -> None: + curve = _curve(interpolation_type) + assert isinstance(curve.discount_factor(1.0), float) + assert isinstance(curve.instantaneous_forward_rate(1.0), float) + + +def test_inherited_rates_method_not_shadowed() -> None: + # the field is anchor_rates, so YieldCurve.rates() still works + curve = _curve() + semi = curve.rates(_TTM, frequency=2) + cont = curve.continuously_compounded_rate(_TTM) + assert np.all(np.asarray(semi) > 0) + # discrete compounding sits just below continuous for positive rates + assert np.all(np.asarray(semi) > np.asarray(cont) - 1e-9) + + +# --------------------------------------------------------------------------- +# Monotonicity and extrapolation +# --------------------------------------------------------------------------- + + +def test_discount_factor_monotone_decreasing( + interpolation_type: InterpolationType, +) -> None: + curve = _curve(interpolation_type) + grid = np.linspace(0.0, 15.0, 400) + df = np.asarray(curve.discount_factor(grid)) + assert np.all(np.diff(df) <= 1e-12) + + +def test_monotone_cubic_introduces_no_new_extrema() -> None: + # a non-monotone forward profile that would make a natural cubic overshoot + rates = [Decimal(r) for r in ("0.05", "0.01", "0.05", "0.01", "0.05")] + curve = _curve(InterpolationType.MONOTONE_CUBIC, rates=rates) + g = np.log(np.asarray(curve.discount_factor(np.linspace(0.0, 10.0, 500)))) + # log discount factor must stay within the envelope of the node values + rates_f = np.array([float(r) for r in rates]) + node_g = np.concatenate([[0.0], -rates_f * _TTM]) + assert g.min() >= node_g.min() - 1e-9 + assert g.max() <= node_g.max() + 1e-9 + + +def test_flat_forward_extrapolation(interpolation_type: InterpolationType) -> None: + curve = _curve(interpolation_type) + f_last = float(curve.instantaneous_forward_rate(10.0)) + assert float(curve.instantaneous_forward_rate(15.0)) == pytest.approx(f_last) + assert float(curve.instantaneous_forward_rate(30.0)) == pytest.approx(f_last) + # discount factor extends consistently with the flat forward + expected = float(curve.discount_factor(10.0)) * math.exp(-f_last * 5.0) + assert float(curve.discount_factor(15.0)) == pytest.approx(expected) + + +def test_linear_forward_is_piecewise_constant() -> None: + curve = _curve(InterpolationType.LINEAR) + # within a single segment (1y..2y) the forward rate is constant + f1 = float(curve.instantaneous_forward_rate(1.2)) + f2 = float(curve.instantaneous_forward_rate(1.8)) + assert f1 == pytest.approx(f2) + + +def test_forward_rate_consistent_with_discount_factor( + interpolation_type: InterpolationType, +) -> None: + # f(t) = -d/dt ln D(t): check against a central finite difference + curve = _curve(interpolation_type) + t, h = 3.0, 1e-5 + g_plus = math.log(float(curve.discount_factor(t + h))) + g_minus = math.log(float(curve.discount_factor(t - h))) + numerical = -(g_plus - g_minus) / (2 * h) + assert float(curve.instantaneous_forward_rate(t)) == pytest.approx( + numerical, rel=1e-4 + ) + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +def test_json_round_trip_via_union(interpolation_type: InterpolationType) -> None: + curve = _curve(interpolation_type) + restored = _ADAPTER.validate_json(_ADAPTER.dump_json(curve)) + assert type(restored) is InterpolatedYieldCurve + assert restored.interpolation_type is interpolation_type + assert restored._ttm == pytest.approx(_TTM) + assert np.asarray(restored.discount_factor(_TTM)) == pytest.approx( + np.asarray(curve.discount_factor(_TTM)) + ) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def test_length_mismatch_raises() -> None: + with pytest.raises((ValidationError, ValueError)): + InterpolatedYieldCurve( + ref_date=_REF, anchor_dates=_dates()[:2], anchor_rates=_RATES + ) + + +def test_non_increasing_dates_raise() -> None: + dates = _dates() + dates[1], dates[2] = dates[2], dates[1] + with pytest.raises((ValidationError, ValueError)): + InterpolatedYieldCurve(ref_date=_REF, anchor_dates=dates, anchor_rates=_RATES) + + +def test_anchor_before_ref_date_raises() -> None: + with pytest.raises((ValidationError, ValueError)): + InterpolatedYieldCurve( + ref_date=_REF, anchor_dates=[_REF], anchor_rates=[Decimal("0.02")] + ) + + +# --------------------------------------------------------------------------- +# Calibration +# --------------------------------------------------------------------------- + + +def test_calibrate_from_ttm_reprices_exactly() -> None: + target = _RATES_F * 1.1 + curve = _curve().calibrator().calibrate(_TTM, target) + assert np.asarray(curve.continuously_compounded_rate(_TTM)) == pytest.approx(target) + assert all(isinstance(r, Decimal) for r in curve.anchor_rates) + + +def test_set_params_updates_log_discount() -> None: + curve = _curve() + calibrator = curve.calibrator() + new_rates = _RATES_F * 0.5 + calibrator.set_params(new_rates) + assert curve._log_discount == pytest.approx(-new_rates * _TTM) + assert np.asarray(curve.continuously_compounded_rate(_TTM)) == pytest.approx( + new_rates + ) diff --git a/quantflow_tests/test_non_inverse_surface.py b/quantflow_tests/test_non_inverse_surface.py index 16cd93a9..7c56788a 100644 --- a/quantflow_tests/test_non_inverse_surface.py +++ b/quantflow_tests/test_non_inverse_surface.py @@ -47,8 +47,8 @@ def _build_loader(ttm: float) -> VolSurfaceLoader: ) for strike in STRIKES: for option_type, call_put in ( - (OptionType.call, 1), - (OptionType.put, -1), + (OptionType.CALL, 1), + (OptionType.PUT, -1), ): mid = _black_mid_usd(strike, call_put, ttm) loader.add_option( diff --git a/quantflow_tests/test_options.py b/quantflow_tests/test_options.py index aa209b2d..a18c811d 100644 --- a/quantflow_tests/test_options.py +++ b/quantflow_tests/test_options.py @@ -38,7 +38,7 @@ def _make_option( forward: float, *, price: Decimal = Decimal(0), - option_type: OptionType = OptionType.call, + option_type: OptionType = OptionType.CALL, ref_date: datetime | None = None, maturity: datetime | None = None, ) -> OptionPrice: @@ -200,7 +200,7 @@ def test_call_put_parity(): option = _make_option(100, 100).calculate_price() assert option.log_strike == 0 assert option.price == option.call_price - option2 = _make_option(100, 100, option_type=OptionType.put).calculate_price() + option2 = _make_option(100, 100, option_type=OptionType.PUT).calculate_price() assert option2.price == option2.put_price assert option2.price == option.put_price assert option2.call_price == option.price @@ -210,7 +210,7 @@ def test_call_put_parity_otm(): option = _make_option(105, 100).calculate_price() assert option.log_strike > 0 assert option.price == option.call_price - option2 = _make_option(105, 100, option_type=OptionType.put).calculate_price() + option2 = _make_option(105, 100, option_type=OptionType.PUT).calculate_price() assert option2.price == option2.put_price assert option2.price == pytest.approx(option.put_price) assert option2.call_price == pytest.approx(option.price) diff --git a/quantflow_tests/test_options_pricer.py b/quantflow_tests/test_options_pricer.py index c48509b3..3e0365cc 100644 --- a/quantflow_tests/test_options_pricer.py +++ b/quantflow_tests/test_options_pricer.py @@ -27,7 +27,7 @@ def test_plot_surface(pricer: OptionPricer): def test_price_call(pricer: OptionPricer): price = pricer.price( - option_type=OptionType.call, + option_type=OptionType.CALL, strike=100, forward=100, ttm=1.0, @@ -45,7 +45,7 @@ def test_wiener_matches_black(strike: int, forward: int) -> None: sigma = 0.3 pricer = OptionPricer(model=WienerProcess(sigma=sigma)) price = pricer.price( - option_type=OptionType.call, strike=strike, forward=forward, ttm=1.0 + option_type=OptionType.CALL, strike=strike, forward=forward, ttm=1.0 ) black = price.black assert float(black.iv) == pytest.approx(sigma, rel=1e-3) @@ -62,10 +62,10 @@ def test_put_call_parity_across_strikes(strike: int, forward: int) -> None: """ pricer = OptionPricer(model=Heston.create(vol=0.2, kappa=2.0, sigma=0.5, rho=-0.5)) call = pricer.price( - option_type=OptionType.call, strike=strike, forward=forward, ttm=0.5 + option_type=OptionType.CALL, strike=strike, forward=forward, ttm=0.5 ) put = pricer.price( - option_type=OptionType.put, strike=strike, forward=forward, ttm=0.5 + option_type=OptionType.PUT, strike=strike, forward=forward, ttm=0.5 ) assert call.price - put.price == pytest.approx(1.0 - strike / forward, abs=1e-9) assert call.delta - put.delta == pytest.approx(1.0, abs=1e-9) @@ -76,13 +76,13 @@ def test_put_call_parity_across_strikes(strike: int, forward: int) -> None: "option_type,strike,forward,expected", [ # calls: payoff max(F - K, 0) / F = max(0, 1 - K/F) - (OptionType.call, 80, 100, 0.2), # ITM - (OptionType.call, 100, 100, 0.0), # ATM - (OptionType.call, 120, 100, 0.0), # OTM + (OptionType.CALL, 80, 100, 0.2), # ITM + (OptionType.CALL, 100, 100, 0.0), # ATM + (OptionType.CALL, 120, 100, 0.0), # OTM # puts: payoff max(K - F, 0) / F = max(0, K/F - 1) - (OptionType.put, 80, 100, 0.0), # OTM - (OptionType.put, 100, 100, 0.0), # ATM - (OptionType.put, 120, 100, 0.2), # ITM + (OptionType.PUT, 80, 100, 0.0), # OTM + (OptionType.PUT, 100, 100, 0.0), # ATM + (OptionType.PUT, 120, 100, 0.2), # ITM ], ) def test_intrinsic_value( @@ -100,7 +100,7 @@ def test_price_in_quote_scales_with_forward() -> None: """`price_in_quote` is the forward-space price multiplied by the forward.""" pricer = OptionPricer(model=Heston.create(vol=0.2, kappa=2.0, sigma=0.5, rho=-0.5)) price = pricer.price( - option_type=OptionType.call, strike=5500, forward=5000, ttm=0.5 + option_type=OptionType.CALL, strike=5500, forward=5000, ttm=0.5 ) assert price.price_in_quote == pytest.approx(price.price * 5000.0, abs=1e-9) @@ -110,9 +110,9 @@ def test_as_option_type_roundtrip(strike: int, forward: int) -> None: """`call.as_option_type(put).as_option_type(call)` recovers the original.""" pricer = OptionPricer(model=Heston.create(vol=0.2, kappa=2.0, sigma=0.5, rho=-0.5)) call = pricer.price( - option_type=OptionType.call, strike=strike, forward=forward, ttm=0.5 + option_type=OptionType.CALL, strike=strike, forward=forward, ttm=0.5 ) - roundtrip = call.as_option_type(OptionType.put).as_option_type(OptionType.call) + roundtrip = call.as_option_type(OptionType.PUT).as_option_type(OptionType.CALL) assert roundtrip.price == pytest.approx(call.price, abs=1e-12) assert roundtrip.delta == pytest.approx(call.delta, abs=1e-12) assert roundtrip.gamma == pytest.approx(call.gamma, abs=1e-12)