diff --git a/src/provider_simenv/agents/farmer.py b/src/provider_simenv/agents/farmer.py index 63c8196..ac6fe1b 100644 --- a/src/provider_simenv/agents/farmer.py +++ b/src/provider_simenv/agents/farmer.py @@ -168,12 +168,12 @@ def _step_bra(self): raises per-unit cost, which raises unit_price automatically. """ env = self.model.environment - farm_capacity = env.get_effective_value("farm_capacity_bra") + farm_capacity = env.get_effective_value("brazil_farms", "supply") self.quantity_available = self.base_yield * farm_capacity if self.quantity_available > 0: # fertilizer price factor raises effective fixed costs this step - fertilizer_factor = env.get_effective_value("fertilizer_price_factor") + fertilizer_factor = env.get_effective_value("fertilizer_supply", "price") effective_costs = self.fixed_costs * fertilizer_factor self.unit_price = (effective_costs / self.quantity_available) * (1.0 + self.margin) else: @@ -220,7 +220,7 @@ def _step_arg(self): farm_capacity_arg allows ARG-specific shocks to be modelled independently. Defaults to 1.0 = always unshocked. """ - farm_capacity = self.model.environment.get_effective_value("farm_capacity_arg") + farm_capacity = self.model.environment.get_effective_value("argentina_farms", "supply") self.quantity_available = self.base_yield * farm_capacity if self.quantity_available > 0: diff --git a/src/provider_simenv/agents/process.py b/src/provider_simenv/agents/process.py index d29b53a..102d816 100644 --- a/src/provider_simenv/agents/process.py +++ b/src/provider_simenv/agents/process.py @@ -71,12 +71,12 @@ def step(self): # Shared helper: receive from upstream list, convert, compute price # ------------------------------------------------------------------ - def _process(self, upstream_list, peer_list, scenario_param: str = ""): + def _process(self, upstream_list, peer_list, shock_key: tuple[str, str] | None = None): """ Pull an equal share of upstream output, apply conversion_ratio, and compute unit_price accounting for yield loss. - scenario_param: scenario param whose effective value scales output. + shock_key: PDL (entity, field) whose effective value scales output. Models indirect capacity reduction from soja shortage. For every 1 unit of output, (1 / conversion_ratio) input units @@ -95,7 +95,7 @@ def _process(self, upstream_list, peer_list, scenario_param: str = ""): self.unit_price = 0.0 return - effective_factor = self.model.environment.get_effective_value(scenario_param) if scenario_param else 1.0 + effective_factor = self.model.environment.get_effective_value(*shock_key) if shock_key else 1.0 total_input = sum(a.quantity_available for a in active_upstream) @@ -135,7 +135,7 @@ def _step_processor(self): self._process( upstream_list=combined_eu, peer_list=self.model.processors, - scenario_param="oil_mill_capacity", + shock_key=("eu_oil_mills", "supply"), ) def _step_feed_manufacturer(self): @@ -143,5 +143,5 @@ def _step_feed_manufacturer(self): self._process( upstream_list=self.model.processors, peer_list=self.model.feed_manufacturers, - scenario_param="feed_mill_capacity", + shock_key=("feed_mills", "supply"), ) diff --git a/src/provider_simenv/agents/transport.py b/src/provider_simenv/agents/transport.py index 568f51b..9d81795 100644 --- a/src/provider_simenv/agents/transport.py +++ b/src/provider_simenv/agents/transport.py @@ -109,12 +109,12 @@ def step(self): # cap at capacity, compute all-in unit_price (commodity + freight) # ------------------------------------------------------------------ - def _move(self, upstream, scenario_param: str = ""): + def _move(self, upstream, shock_key: tuple[str, str] | None = None): """ Pull an equal share of upstream output, ca at own capacity, and compute the all-in price passed to the next chain node. - scenario_param: scenario param whose effective value scales this agent's capacity, via env.get_effective_value() + shock_key: PDL (entity, field) whoe effective value scales this agent's capacity """ margin = self.scenario.margin_transport @@ -135,7 +135,7 @@ def _move(self, upstream, scenario_param: str = ""): volume_in = total_volume / n_self env = self.model.environment - effective_factor = env.get_effective_value(scenario_param) if scenario_param else 1.0 + effective_factor = env.get_effective_value(*shock_key) if shock_key else 1.0 # effective capacity after applying port capacity shock effective_capacity = self.capacity * effective_factor @@ -151,7 +151,7 @@ def _move(self, upstream, scenario_param: str = ""): # price = commodity price + freight fee per unit # energy price factor raises transport operation costs if self.quantity_available > 0: - energy_factor = env.get_effective_value("energy_price_factor") + energy_factor = env.get_effective_value("gas_supply", "price") effective_costs = self.fixed_costs * energy_factor freight_fee = (effective_costs / self.quantity_available) * (1.0 + margin) self.unit_price = upstream_price + freight_fee @@ -159,13 +159,13 @@ def _move(self, upstream, scenario_param: str = ""): self.unit_price = 0.0 - def _move_split(self, upstream_list, share: float, scenario_param: str = "", exclude_arg=False, exclude_usa=False): + def _move_split(self, upstream_list, share: float, shock_key: tuple[str, str] | None = None, exclude_arg=False, exclude_usa=False): """ Like _move, but routes only share fraction of total upstream volume through this port. Used to split wholesaler output between Santos and Paranagua. :param share: fraction of total wholesaler output for this port (e.g. 0.7 for Santos, 0.3 for Paranagua). - :param scenario_param: scenario param whose effective value scales this agent's capacity. + :param shock_key: PDL (entity, field) whoe effective value scales this agent's capacity """ margin = self.scenario.margin_transport active_upstream = upstream_list.filter(lambda a: a.active) @@ -192,7 +192,7 @@ def _move_split(self, upstream_list, share: float, scenario_param: str = "", exc volume_in = (routable_volume * share) / n_self env = self.model.environment - effective_factor = env.get_effective_value(scenario_param) if scenario_param else 1.0 + effective_factor = env.get_effective_value(*shock_key) if shock_key else 1.0 effective_capacity = self.capacity * effective_factor self.quantity_available = min(volume_in, effective_capacity) @@ -204,7 +204,7 @@ def _move_split(self, upstream_list, share: float, scenario_param: str = "", exc upstream_price = (total_value / total_volume) if total_volume > 0 else 0.0 if self.quantity_available > 0: - energy_factor = env.get_effective_value("energy_price_factor") + energy_factor = env.get_effective_value("gas_supply", "price") effective_costs = self.fixed_costs * energy_factor freight_fee = (effective_costs / self.quantity_available) * (1.0 + margin) self.unit_price = upstream_price + freight_fee @@ -244,7 +244,7 @@ def _step_sa_santos(self): self._move_split( self.model.wholesalers, share=self.scenario.santos_share, - scenario_param="port_capacity_santos", + shock_key=("santos_port", "supply"), exclude_arg=True, exclude_usa=True, ) @@ -257,7 +257,7 @@ def _step_sa_paranagua(self): self._move_split( self.model.wholesalers, share=1.0 - self.scenario.santos_share, - scenario_param="port_capacity_paranagua", + shock_key=("paranagua_port", "supply"), exclude_arg=True, exclude_usa=True, ) @@ -315,7 +315,7 @@ def _step_sea_arg(self): upstream_price = total_value / total_arg if self.quantity_available > 0: - energy_factor = self.model.environment.get_effective_value("energy_price_factor") + energy_factor = self.model.environment.get_effective_value("gas_supply", "price") effective_costs = self.fixed_costs * energy_factor freight_fee = (effective_costs / self.quantity_available) * (1.0 + margin) self.unit_price = upstream_price + freight_fee @@ -364,7 +364,7 @@ def _step_sea_usa(self): upstream_price = total_value / total_usa if self.quantity_available > 0: - energy_factor = self.model.environment.get_effective_value("energy_price_factor") + energy_factor = self.model.environment.get_effective_value("gas_supply", "price") effective_costs = self.fixed_costs * energy_factor freight_fee = (effective_costs / self.quantity_available) * (1.0 + margin) self.unit_price = upstream_price + freight_fee @@ -382,7 +382,7 @@ def _step_eu_rtm(self): + self.model.sea_lane_arg.filter(lambda a: a.active) + self.model.sea_lane_usa.filter(lambda a: a.active) ) - self._move(combined, scenario_param="port_capacity_rotterdam") + self._move(combined, shock_key=("rotterdam_port", "supply")) def _step_eu_ham(self): @@ -393,6 +393,6 @@ def _step_eu_ham(self): """ self._move( self.model.sea_lane_paranagua, - scenario_param="port_capacity_hamburg", + shock_key=("hamburg_port", "supply"), ) diff --git a/src/provider_simenv/data/input/SimulatorScenarios.csv b/src/provider_simenv/data/input/SimulatorScenarios.csv index 0c2d17d..63d3e12 100644 --- a/src/provider_simenv/data/input/SimulatorScenarios.csv +++ b/src/provider_simenv/data/input/SimulatorScenarios.csv @@ -1,4 +1,3 @@ -id,run_num,period_num,n_bra_farmers,n_arg_farmers,n_usa_farmers,n_wholesalers,n_transport_sa_santos,n_transport_sa_paranagua,n_sea_lane_santos,n_sea_lane_paranagua,n_sea_lane_arg,n_sea_lane_usa,n_transport_eu_rtm,n_transport_eu_ham,n_processors,n_feed_manufacturers,n_feed_traders,n_eu_farmers,farm_capacity_bra,farm_capacity_arg,port_capacity_santos,port_capacity_paranagua,port_capacity_rotterdam,port_capacity_hamburg,santos_share,fertilizer_price_factor,energy_price_factor,oil_mill_capacity,feed_mill_capacity,shock_ramp_steps,shock_onset_farm_bra,shock_end_farm_bra,shock_onset_farm_arg,shock_end_farm_arg,shock_onset_port_santos,shock_end_port_santos,shock_onset_port_paranagua,shock_end_port_paranagua,shock_onset_port_rotterdam,shock_end_port_rotterdam,shock_onset_port_hamburg,shock_end_port_hamburg,shock_onset_fertilizer,shock_end_fertilizer,shock_onset_energy,shock_end_energy,shock_onset_oil_mill,shock_end_oil_mill,shock_onset_feed_mill,shock_end_feed_mill,farm_size_sigma_bra,farm_size_sigma_eu,farm_size_seed,wholesaler_storage_capacity -0,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,1.0,1.0,1.0,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.0,0.0,42,2857.0 -1,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.6,1.0,1.0,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.4,0.4,42,2857.0 -2,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.6,1.0,0.5,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.4,0.4,42,2857.0 +id,run_num,period_num,n_bra_farmers,n_arg_farmers,n_usa_farmers,n_wholesalers,n_transport_sa_santos,n_transport_sa_paranagua,n_sea_lane_santos,n_sea_lane_paranagua,n_sea_lane_arg,n_sea_lane_usa,n_transport_eu_rtm,n_transport_eu_ham,n_processors,n_feed_manufacturers,n_feed_traders,n_eu_farmers,santos_share,shock_ramp_steps,farm_size_sigma_bra,farm_size_sigma_eu,farm_size_seed,wholesaler_storage_capacity +0,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.7,0,0.0,0.0,42,2857.0 +1,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.7,0,0.0,0.0,42,2857.0 diff --git a/src/provider_simenv/data/input/SimulatorScenarios_template.csv b/src/provider_simenv/data/input/SimulatorScenarios_template.csv index 7227ac1..319daf4 100644 --- a/src/provider_simenv/data/input/SimulatorScenarios_template.csv +++ b/src/provider_simenv/data/input/SimulatorScenarios_template.csv @@ -1,4 +1,4 @@ -id,run_num,period_num,n_bra_farmers,n_arg_farmers,n_usa_farmers,n_wholesalers,n_transport_sa_santos,n_transport_sa_paranagua,n_sea_lane_santos,n_sea_lane_paranagua,n_sea_lane_arg,n_sea_lane_usa,n_transport_eu_rtm,n_transport_eu_ham,n_processors,n_feed_manufacturers,n_feed_traders,n_eu_farmers,farm_capacity_bra,farm_capacity_arg,port_capacity_santos,port_capacity_paranagua,port_capacity_rotterdam,port_capacity_hamburg,santos_share,fertilizer_price_factor,energy_price_factor,oil_mill_capacity,feed_mill_capacity,shock_ramp_steps,shock_onset_farm_bra,shock_end_farm_bra,shock_onset_farm_arg,shock_end_farm_arg,shock_onset_port_santos,shock_end_port_santos,shock_onset_port_paranagua,shock_end_port_paranagua,shock_onset_port_rotterdam,shock_end_port_rotterdam,shock_onset_port_hamburg,shock_end_port_hamburg,shock_onset_fertilizer,shock_end_fertilizer,shock_onset_energy,shock_end_energy,shock_onset_oil_mill,shock_end_oil_mill,shock_onset_feed_mill,shock_end_feed_mill,farm_size_sigma_bra,farm_size_sigma_eu,farm_size_seed,wholesaler_storage_capacity -0,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,1.0,1.0,1.0,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.0,0.0,42,2857.0 -1,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.6,1.0,1.0,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.4,0.4,42,2857.0 -2,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.6,1.0,0.5,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0,365,0.4,0.4,42,2857.0 \ No newline at end of file +id,run_num,period_num,n_bra_farmers,n_arg_farmers,n_usa_farmers,n_wholesalers,n_transport_sa_santos,n_transport_sa_paranagua,n_sea_lane_santos,n_sea_lane_paranagua,n_sea_lane_arg,n_sea_lane_usa,n_transport_eu_rtm,n_transport_eu_ham,n_processors,n_feed_manufacturers,n_feed_traders,n_eu_farmers,santos_share,shock_ramp_steps,farm_size_sigma_bra,farm_size_sigma_eu,farm_size_seed,wholesaler_storage_capacity +0,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.7,0,0.0,0.0,42,2857.0 +1,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.7,0,0.4,0.4,42,2857.0 +2,1,365,10,5,8,3,1,1,1,1,1,1,1,1,3,3,3,10,0.7,0,0.4,0.4,42,2857.0 \ No newline at end of file diff --git a/src/provider_simenv/environment.py b/src/provider_simenv/environment.py index ae3211e..d7c957c 100644 --- a/src/provider_simenv/environment.py +++ b/src/provider_simenv/environment.py @@ -16,23 +16,9 @@ from typing import TYPE_CHECKING from Melodie import Environment -if TYPE_CHECKING: - from event_tracker import EventTracker - -# maps scenario param name -> (onset_field, end_field) on SupplyChainScenario -_PARAM_TIMING_FIELDS: list[tuple[str, str, str]] = [ - ("farm_capacity_bra", "shock_onset_farm_bra", "shock_end_farm_bra"), - ("farm_capacity_arg", "shock_onset_farm_arg", "shock_end_farm_arg"), - ("port_capacity_santos", "shock_onset_port_santos", "shock_end_port_santos"), - ("port_capacity_paranagua", "shock_onset_port_paranagua", "shock_end_port_paranagua"), - ("port_capacity_rotterdam", "shock_onset_port_rotterdam", "shock_end_port_rotterdam"), - ("port_capacity_hamburg", "shock_onset_port_hamburg", "shock_end_port_hamburg"), - ("fertilizer_price_factor", "shock_onset_fertilizer", "shock_end_fertilizer"), - ("energy_price_factor", "shock_onset_energy", "shock_end_energy"), - ("oil_mill_capacity", "shock_onset_oil_mill", "shock_end_oil_mill"), - ("feed_mill_capacity", "shock_onset_feed_mill", "shock_end_feed_mill"), -] +from .shock_registry import DROUGHT_KEY +from .event_tracker import EventTracker class SupplyChainEnvironment(Environment): @@ -47,7 +33,7 @@ class SupplyChainEnvironment(Environment): feed_price: float = 0.0 # global shock intensity - # Agents should call get_shock_scale(param) instead of reading this directly. + # Agents should call get_shock_scale(entity, field) instead of reading this directly. shock_scale: float = 0.0 # drought severity this step @@ -78,10 +64,8 @@ def setup(self): self.transport_utilisation = 0.0 self.current_step = 0 - # per-parameter shock activation scale - self.shock_scales: dict[str, float] = { - param: 0.0 for param, _, _ in _PARAM_TIMING_FIELDS - } + # per (entity, field) shock activation scale. + self.shock_scales: dict[tuple[str, str], float] = {} def update_shock_scales(self, period: int): @@ -94,43 +78,42 @@ def update_shock_scales(self, period: int): """ if self._tracker is not None: self._tracker.step(period) - for param, _, _ in _PARAM_TIMING_FIELDS: - self.shock_scales[param] = self._tracker.get_shock_scale(param) + # seed the key once + if not self.shock_scales: + self.shock_scales = {key: 0.0 for key in self._tracker.known_keys()} + for key in self.shock_scales: + self.shock_scales[key] = self._tracker.get_shock_scale(*key) else: - for param, onset_field, end_field in _PARAM_TIMING_FIELDS: - onset = getattr(self.scenario, onset_field) - end = getattr(self.scenario, end_field) - value = getattr(self.scenario, param) - has_shock = value != 1.0 - self.shock_scales[param] = (1.0 if has_shock and onset <= period < end else 0.0) + for key in self.shock_scales: + self.shock_scales[key] = 0.0 self.shock_scale = max(self.shock_scales.values(), default=0.0) - # drought severity: use racker value if available - bra_scale = self.shock_scales.get("farm_capacity_bra", 0.0) - bra_value = self.get_effective_value("farm_capacity_bra") - self.drought_severity = ( - bra_scale * (1.0 - bra_value) - ) + # Drought severity is defined as brazil_farms supply degradation (DROUGHT_KEY) + bra_scale = self.shock_scales.get(DROUGHT_KEY, 0.0) + bra_value = self.get_effective_value(*DROUGHT_KEY) + self.drought_severity = (bra_scale * (1.0 - bra_value)) + + - def get_shock_scale(self, param: str) -> float: + def get_shock_scale(self, entity: str, field: str) -> float: """ Return the current shock actibation scale for a scenario parameter. """ - return self.shock_scales.get(param, 0.0) + return self.shock_scales.get((entity, field), 0.0) - def get_effective_value(self, param: str) -> float: + def get_effective_value(self, entity: str, field: str) -> float: """ Return the effective value for this step. Tracker mode: aggregated from currently active events only. - Static mode: reads fixed value from the scenario. + No tracker (baseline / non-PDL): unshocked, always 1.0 """ if self._tracker is not None: - return self._tracker.get_param_value(param) - return getattr(self.scenario, param, 1.0) + return self._tracker.get_param_value(entity, field) + return 1.0 def step(self): diff --git a/src/provider_simenv/event_tracker.py b/src/provider_simenv/event_tracker.py index 06d854a..2d456ef 100644 --- a/src/provider_simenv/event_tracker.py +++ b/src/provider_simenv/event_tracker.py @@ -9,19 +9,22 @@ from __future__ import annotations from dataclasses import dataclass +from .shock_registry import aggregate + @dataclass(frozen=True) class EventDef: """ - Single PDL event definition, pre-mapped to simenv params. - Events without a simenv param (param: None) are also tracked. - Other events may depend on them. + Single PDL event definition. + Keys are the PDL (entity, impact_field) pairs, passed through unchanged. + No translation to simenv param names. + Events with no supply(price impact (impacts={}) are still tracked, + so other events can depend on them via conditions. """ id: str - param: str | None # simenv param name, None if unmapped - value: float | None # converted impact value - duration: int # duration (days) - condition: str # condition string, "" = unconditional - impact_field: str # supply || price + entity: str | None + impacts: dict[str, float] + duration: int + condition: str @dataclass(frozen=True) @@ -42,16 +45,6 @@ class ActiveEvent: event_def: EventDef -# --- Aggregation constants --- -_CAPACITY_PARAMS = { - "farm_capacity_bra", "farm_capacity_arg", - "port_capacity_santos", "port_capacity_paranagua", - "port_capacity_rotterdam", "port_capacity_hamburg", - "oil_mill_capacity", "feed_mill_capacity", -} -_PRICE_PARAMS = {"energy_price_factor", "fertilizer_price_factor"} - - # --- Event tracker --- class EventTracker: """ @@ -71,11 +64,10 @@ def __init__(self, events: list[dict], timeline: list[dict]) -> None: for e in events: self._events[e["id"]] = EventDef( id=e["id"], - param=e.get("param"), - value=e.get("value"), + entity=e.get("entity"), + impacts=e.get("impacts") or {}, duration=e.get("duration", 0), - condition=e.get("condition", ""), - impact_field=e.get("impact_field", "supply"), + condition=e.get("condition"), ) # timeline sorted by day @@ -91,8 +83,8 @@ def __init__(self, events: list[dict], timeline: list[dict]) -> None: self._current_day: int = -1 # derived shock state (re-calc every step) - self._shock_scales: dict[str, float] = {} - self._param_values: dict[str, float] = {} + self._shock_scales: dict[tuple[str, str], float] = {} + self._param_values: dict[tuple[str, str], float] = {} def step(self, day: int) -> None: @@ -155,19 +147,33 @@ def step(self, day: int) -> None: self._recompute_scales() - def get_shock_scale(self, param: str) -> float: + def get_shock_scale(self, entity: str, field: str) -> float: """ 1.0 if param has any active shock, 0.0 otherwise. """ - return self._shock_scales.get(param, 0.0) + return self._shock_scales.get((entity, field), 0.0) - def get_param_value(self, param: str) -> float: + def get_param_value(self, entity: str, field: str) -> float: """ Aggregated shock value from active events targeting this param. Returns 1.0 if no active events targeting this param. """ - return self._param_values.get(param, 1.0) + return self._param_values.get((entity, field), 1.0) + + + def known_keys(self) -> set[tuple[str, str]]: + """ + Every (entity, field) key the loaded events can emit. Lets the environment + seed/iterate shock state without a static param list. + """ + keys: set[tuple[str, str]] = set() + for edef in self._events.values(): + if edef.entity is None: + continue + for field in edef.impacts: + keys.add((edef.entity, field)) + return keys def is_event_active(self, event_id: str) -> int: @@ -240,17 +246,14 @@ def _recompute_scales(self) -> None: for active in self._active.values(): edef = active.event_def - if edef.param is not None and edef.value is not None: - candidates.setdefault(edef.param, []).append(edef.value) + if edef.entity is None: + continue + for field, value in edef.impacts.items(): + candidates.setdefault((edef.entity, field), []).append(value) self._shock_scales.clear() self._param_values.clear() - for param, values in candidates.items(): - self._shock_scales[param] = 1.0 - if param in _CAPACITY_PARAMS: - self._param_values[param] = min(values) - elif param in _PRICE_PARAMS: - self._param_values[param] = max(values) - else: - self._param_values[param] = values[0] + for (entity, field), values in candidates.items(): + self._shock_scales[(entity, field)] = 1.0 + self._param_values[(entity, field)] = aggregate(field, values) diff --git a/src/provider_simenv/main.py b/src/provider_simenv/main.py index 522028b..1e0bff9 100644 --- a/src/provider_simenv/main.py +++ b/src/provider_simenv/main.py @@ -14,24 +14,11 @@ import pandas as pd from Melodie import Config, Simulator -from scipy.constants import value from provider_simenv.model import SupplyChainModel from provider_simenv.scenario import SupplyChainScenario from provider_simenv.pdl_loader import PDLLoader -_PDL_TIMING_COLUMNS = { - "farm_capacity_bra": ("shock_onset_farm_bra", "shock_end_farm_bra"), - "farm_capacity_arg": ("shock_onset_farm_arg", "shock_end_farm_arg"), - "port_capacity_santos": ("shock_onset_port_santos", "shock_end_port_santos"), - "port_capacity_paranagua": ("shock_onset_port_paranagua", "shock_end_port_paranagua"), - "port_capacity_rotterdam": ("shock_onset_port_rotterdam", "shock_end_port_rotterdam"), - "port_capacity_hamburg": ("shock_onset_port_hamburg", "shock_end_port_hamburg"), - "fertilizer_price_factor": ("shock_onset_fertilizer", "shock_end_fertilizer"), - "energy_price_factor": ("shock_onset_energy", "shock_end_energy"), - "oil_mill_capacity": ("shock_onset_oil_mill", "shock_end_oil_mill"), - "feed_mill_capacity": ("shock_onset_feed_mill", "shock_end_feed_mill"), -} # -------------------- # Helpers # -------------------- @@ -124,67 +111,38 @@ def csv_to_sqlite(output_dir: str, db_name: str = "provider-simenv.sqlite") -> N if os.path.exists(template_path): shutil.copy2(template_path, csv_path) - # PDL Injection: update shock columns in SimulatorScenarios.csv + # PDL Injection: a PDL run adds one shock scenario row (id=1) to SimulatorScenario.csv + # Shock values and timing are derived at runtime by the EventTracker from the PDL itself if args.pdl: loader = PDLLoader(args.pdl) - schedule = loader.to_cascade_schedule(args.cascade) - all_overrides = loader.to_scenario_overrides() - overrides = { - param: value - for param, value in all_overrides.items() - if param in schedule - } - - print(f"\n[pdl_loader] Scenario : {loader.label}") - print(f"[pdl_loader] Source : {args.pdl}") - print(f"[pdl_loader] Cascade: {args.cascade or 'first cascade in file'}") - print("[pdl_loader] Overrides applied to SimulatorScenarios.csv (id > 0):") - for col, val in overrides.items(): - print(f"{col} = {val}") + + print(f"\n[pdl_loader] Scenario: {loader.label}") + print(f"\n[pdl_loader] Source: {args.pdl}") + print(f"\n[pdl_loader] Cascade: {args.cascade or 'first cascade in file'}") df = pd.read_csv(csv_path) # keep only the baseline row (id=0) baseline = df[df["id"] == 0].copy() - # build exactly one PDL scenario row from the baseline + # build exactly one PDL scenario row from the baseline (shocks injected at runtime) pdl_row = baseline.iloc[0].copy() pdl_row["id"] = 1 - # apply shock value override - for col, val in overrides.items(): - if col in df.columns: - pdl_row[col] = val - else: - print(f"[pdl_loader] WARNING: column {col} not found, skipping") - - # apply cascade timing - print("[pdl_loader] Cascade timing applied to PDL scenario row (id=1):") - for param, timing in schedule.items(): - fields = _PDL_TIMING_COLUMNS.get(param) - if fields is None: - continue - onset_col, end_col = fields - pdl_row[onset_col] = timing["onset"] - pdl_row[end_col] = timing["end"] - print( - f" {param}: " - f"{onset_col}={timing['onset']}, {end_col}={timing['end']}" - ) - df = pd.concat([baseline, pdl_row.to_frame().T], ignore_index=True) for col in baseline.select_dtypes(include="int64").columns: df[col] = df[col].astype(int) df.to_csv(csv_path, index=False) - print(f"[pdl_loader] CSV updated (baseline + 1 PDL scenario). \n") + print(f"[pdl_loader] CSV updated (baseline + 1 PDL scenario row). \n") - # Build event registry for conditional runtime evaluation + # build event registry for conditional runtime evaluation event_registry = loader.to_event_registry(args.cascade) n_total = len(event_registry["events"]) - n_mapped = sum(1 for e in event_registry["events"] if e["param"] is not None) + n_shocking = sum(1 for e in event_registry["events"] if e["impacts"]) n_conditional = sum(1 for e in event_registry["events"] if e["condition"]) print(f"[event_tracker] Registry: {n_total} events " - f"({n_mapped} mapped, {n_conditional} conditional)") + f"{n_shocking} with shocks, {n_conditional} conditional)") + config = Config( diff --git a/src/provider_simenv/model.py b/src/provider_simenv/model.py index 56220e9..0144425 100644 --- a/src/provider_simenv/model.py +++ b/src/provider_simenv/model.py @@ -122,7 +122,7 @@ def setup(self): self._setup_with_role(self.processors, self.scenario.n_processors, ROLE_PROCESSOR) self._setup_with_role(self.feed_manufacturers, self.scenario.n_feed_manufacturers, ROLE_FEED_MANUFACTURER) - self._prev_shock_scales: dict[str, float] = {} + self._prev_shock_scales: dict[tuple[str, str], float] = {} self._prev_active_events: set[str] = set() self._heartbeat_interval: int = 30 @@ -150,17 +150,19 @@ def _collect_snapshot(self) -> dict: } - def _log_event(self, t: int, direction: str, param: str, snap: dict): + def _log_event(self, t: int, direction: str, key: tuple[str, str], snap: dict): """ layer 1: emit one line per shock state transition """ - value = self.environment.get_effective_value(param) + entity, field = key + label = f"{entity}/{field}" + value = self.environment.get_effective_value(entity, field) if direction == "ON": pct = (value - 1.0) * 100 sign = "+" if pct > 0 else "" - print(f" ▸ DAY {t:03d} ON {param:<28s} {value:.2f} ({sign}{pct:.0f}%)") + print(f" ▸ DAY {t:03d} ON {label:<28s} {value:.2f} ({sign}{pct:.0f}%)") else: - print(f" ▸ DAY {t:03d} OFF {param:<28s} → 1.00") + print(f" ▸ DAY {t:03d} OFF {label:<28s} → 1.00") def _log_hearbeat(self, t: int, snap: dict): @@ -208,10 +210,14 @@ def _do_step(self, t: int) -> None: for eid in sorted(activated): edef = tracker._events.get(eid) reason = f"condition: {edef.condition}" if edef and edef.condition else "unconditional" - param_str = f" -> {edef.param}={edef.value:.2f}" if edef and edef.param else "" + if edef and edef.impacts: + impact_str = " -> " + ", ".join( + f"{edef.entity}/{f}={v:.2f}" for f, v in edef.impacts.items() + ) + else: + impact_str = "" dur_str = f", expires day {t + edef.duration}" if edef and edef.duration > 0 else ", permanent" - print(f" © DAY {t:03d} EVENT ON {eid:<35s} ({reason}{param_str}{dur_str})") - + print(f" © DAY {t:03d} EVENT ON {eid:<35s} ({reason}{impact_str}{dur_str})") for eid in sorted(expired): print(f" ® DAY {t:03d} EVENT OFF {eid:<35s} (duration elapsed)") @@ -249,12 +255,12 @@ def _do_step(self, t: int) -> None: current_scales = dict(self.environment.shock_scales) # layer 1: detect transitions - for param, scale in current_scales.items(): - prev = self._prev_shock_scales.get(param, 0.0) + for key, scale in current_scales.items(): + prev = self._prev_shock_scales.get(key, 0.0) if prev == 0.0 and scale > 0.0: - self._log_event(t, "ON", param, snap) + self._log_event(t, "ON", key, snap) elif prev > 0.0 and scale == 0.0: - self._log_event(t, "OFF", param, snap) + self._log_event(t, "OFF", key, snap) self._prev_shock_scales = current_scales diff --git a/src/provider_simenv/pdl_loader.py b/src/provider_simenv/pdl_loader.py index ec6933b..d18c3ce 100644 --- a/src/provider_simenv/pdl_loader.py +++ b/src/provider_simenv/pdl_loader.py @@ -9,50 +9,18 @@ supply impact: capacity = 1.0 + pct / 100 e.g. "-40%" -> 0.60 price impact: price_factor = 1.0 + pct / 100 e.g. "+200%" -> 3.0 -Aggregation (when multiple events target the same entity) ----------------------------------------------------------- - capacity params -> take min() # worst case supply degradation - price params -> take max() # worst case price spike +No translation to model param names happens here. Each event carries its PDL target entity +and an impacts dict keyed by the PDL impact field. +The (entity, field) pair is passed through unchanged. Aggregation across events that share a key is +the EvenTracker's ob (shock_registry.aggregate) """ -import warnings from pathlib import Path -from traceback import format_exc import yaml -# mapping table -# --------------------------------------------------------------------------- -# Key : (PDL entity id, impact field) - "supply" or "price" -# Value : SupplyChainScenario field name -# -# Only 5 parameters the current simenv model can consume are listed. -# Both santos_port and paranagua_port map to the same param (port_capacity_sa) -# when both appear the min() is taken -#---------------------------------------------------------------------------- - -_PDL_MAPPING: dict[tuple[str, str], str] = { - ("brazil_farms", "supply"): "farm_capacity_bra", - ("argentina_farms", "supply"): "farm_capacity_arg", - ("santos_port", "supply"): "port_capacity_santos", - ("paranagua_port", "supply"): "port_capacity_paranagua", - ("rotterdam_port", "supply"): "port_capacity_rotterdam", - ("hamburg_port", "supply"): "port_capacity_hamburg", - ("gas_supply", "price"): "energy_price_factor", - ("fertilizer_supply", "price"): "fertilizer_price_factor", - ("eu_oil_mills", "supply"): "oil_mill_capacity", -} - -# params where lower = worse (capacity degradation) -> aggregate with min() -_CAPACITY_PARAMS = {"farm_capacity_bra", "farm_capacity_arg", - "port_capacity_santos", "port_capacity_paranagua", - "port_capacity_rotterdam", "port_capacity_hamburg", - "oil_mill_capacity", "feed_mill_capacity"} - -# params where higher = worse (price spikes) -> aggregate with max() -_PRICE_PARAMS = {"energy_price_factor", "fertilizer_price_factor"} # ------- @@ -105,59 +73,6 @@ def __init__(self, path: str | Path) -> None: self.label: str = self._doc.get("scenario", {}).get("name", self.path.stem) - def to_scenario_overrides(self) -> dict[str, float]: - """ - Scan PDL events and return scenario parameters overrides. - - Only the 5 parameters th current simenv model can consume are populated. - All other PDL events are silently skipped on the current version. - - Returns - -------- - dict[str, float] - e.g. {"farm_capacity_bra": 0.60, "port_capacity_sa": 0.80, ...} - """ - - # collect all candidate values per scenario param name - candidates: dict[str, list[float]] = {} - - events: list[dict] = self._doc.get("events") or [] - - for event in events: - target: str = (event.get("trigger") or {}).get("target", "") - impact: dict = event.get("impact") or {} - - for field in ("supply", "price"): - raw = impact.get(field) - if raw is None: - continue - - param = _PDL_MAPPING.get((target, field)) - if param is None: - continue # not a mapped combination, skip silently - - pct = _parse_percent(str(raw)) - value = round(1.0 + pct / 100.0, 6) - candidates.setdefault(param, []).append(value) - - # reduce candidates to a single value per param - overrides: dict[str, float] = {} - for param, values in candidates.items(): - if param in _CAPACITY_PARAMS: - overrides[param] = min(values) # worst-case supply - elif param in _PRICE_PARAMS: - overrides[param] = max(values) # worst-case price - else: - overrides[param] = values[0] - - if not overrides: - warnings.warn( - f"PDLLoader: no mappable events found in '{self.path.name}'." - "Check that entity IDs match the expected PDL mapping.", - stacklevel=2, - ) - - return overrides def _get_cascade(self, cascade_id: str | None) -> dict: @@ -178,85 +93,14 @@ def _get_cascade(self, cascade_id: str | None) -> dict: ) - def _build_event_index(self) -> dict[str, dict]: - """ - Return a dict mapping event id -> event dict for fast lookup. - """ - return {e["id"]: e for e in (self._doc.get("events") or [])} - - - def to_cascade_schedule(self, cascade_id: str | None = None) -> dict[str, dict[str, int]]: - """ - Parse a cascade timeline and return per-parameter shock schedule. - - For each scenario parameter the current model can consume, - returns the onset day (when the shock starts) and end day (when it ends). - Onset comes from the cascade timeline's 'at:' field. - End is onset + the matching event's 'impact.duration'. - - When multiple timeline entries map to the same scenario parameter: - onset = min(all onset days) -> shock starts at earliest trigger - end = max(all end days) -> shock lasts until latest event expires - - e.g. - - {at: 14d, event: soy_export_reduction} -> santos_port -> port_capacity_santos - - {at: 21d, event: port_congestion} -> santos_port -> port_capacity_santos - - :param cascade_id: - cascade_id: str | None - ID of the cascade to read (e.g. "soy_crisis_cascade") - If None, the first cascae in the file is used. - - :return: - dict[str, dict[str, int]] - e.g. - { - "farm_capacity_bra": {"onset": 0, "end": 90}, - "port_capacity_santos": {"onset": 14, "end": 134}, - } - """ - cascade = self._get_cascade(cascade_id) - event_index = self._build_event_index() - timeline = cascade.get("timeline") or [] - - candidates: dict[str, list[tuple[int, int]]] = {} - - for entry in timeline: - onset_day = _parse_duration(entry.get("at", "0d")) - event_id = entry.get("event", "") - event = event_index.get(event_id) - if event is None: - continue - - target = (event.get("trigger") or {}).get("target", "") - impact = event.get("impact") or {} - duration_raw = impact.get("duration") - duration_days = _parse_duration(duration_raw) if duration_raw else 0 - end_day = onset_day + duration_days - - for field in ("supply", "price"): - if impact.get(field) is None: - continue - param = _PDL_MAPPING.get((target, field)) - if param is None: - continue - candidates.setdefault(param, []).append((onset_day, end_day)) - - # earliest onset, latest end - return { - param: { - "onset": min(p[0] for p in pairs), - "end": max(p[1] for p in pairs), - } - for param, pairs in candidates.items() - } - - def to_event_registry(self, cascade_id: str | None = None) -> dict: """ - Export event definitions and cascade timeline for the EventTracker. + Export event definitions and casecade timeline for the EventTracker. - All PDL events are included, events without a simenv param mapping get param=None / value=None + Every PDL event is included. Each event carries its 'target' entity and an 'impacts' dict + mapping each present impact field to its converted multiplier, + - {"supply": 0.30, "price": 1.80} + Events with no supply/price impact get impacts={} """ events: list[dict] = [] for event in (self._doc.get("events") or []): @@ -269,30 +113,25 @@ def to_event_registry(self, cascade_id: str | None = None) -> dict: duration_raw = impact.get("duration") duration = _parse_duration(duration_raw) if duration_raw else 0 - # find the first mapped (target, field) - param = None - value = None - impact_field = "supply" - + # Carry every supply/price impact through verbatim as a multiplier, + # keyed by its PDL field. Both are emitted when both are present, + # so one entity can shock supply and price independently. + # TODO: 'demand' is parsed but skipped: nothing consumes it yet and there is + # no aggregation rule for it (future extension) + impacts: dict[str, float] = {} for field in ("supply", "price"): raw = impact.get(field) if raw is None: continue - mapped = _PDL_MAPPING.get((target, field)) - if mapped is not None: - pct = _parse_percent(str(raw)) - param = mapped - value = round(1.0 + pct / 100.0, 6) - impact_field = field - break # one param per event + pct = _parse_percent(str(raw)) + impacts[field] = round(1.0 + pct / 100.0, 6) events.append({ "id": eid, - "param": param, - "value": value, + "entity": target, + "impacts": impacts, "duration": duration, "condition": condition, - "impact_field": impact_field, }) # --- cascade timeline --- diff --git a/src/provider_simenv/scenario.py b/src/provider_simenv/scenario.py index 12b1005..0f48168 100644 --- a/src/provider_simenv/scenario.py +++ b/src/provider_simenv/scenario.py @@ -13,47 +13,8 @@ class SupplyChainScenario(Scenario): All numeric fields are read by Melodie from the scenarios table. """ - # --- KG shock coefficients --- - farm_capacity_bra: float = 1.0 # BRA soja farm output multiplier (0.7 = 30% drought loss) - - port_capacity_santos: float = 1.0 # Santos port throughput multiplier - port_capacity_paranagua: float = 1.0 # Paranagua port throughput multiplier - port_capacity_rotterdam: float = 1.0 # Rotterdam port throughput multiplier - port_capacity_hamburg: float = 1.0 # Hamburg port throughput multiplier - santos_share: float = 0.7 # 0.7 = 70% of BRA exports via Santos - - fertilizer_price_factor: float = 1.0 # multiplier on SA farmer fixed costs (1.3 = + 30%) - energy_price_factor: float = 1.0 # multiplier on all transport fixed costs (1.5 = + 50%) - - oil_mill_capacity: float = 1.0 # EU processor output multiplier (0.95 = slight soja shortage) - feed_mill_capacity: float = 1.0 # feed manufacturer output multiplier - - shock_ramp_steps: int = 0 - - # --- Per-parameter onset and end days --- - # onset: simulation day the shock activates (inclusive) - # end: simulation day the shock deactivates (exclusive) - # Active window: onset <= t < end - shock_onset_farm_bra: int = 0 - shock_end_farm_bra: int = 365 - shock_onset_farm_arg: int = 0 - shock_end_farm_arg: int = 365 - shock_onset_port_santos: int = 0 - shock_end_port_santos: int = 365 - shock_onset_port_paranagua: int = 0 - shock_end_port_paranagua: int = 365 - shock_onset_port_rotterdam: int = 0 - shock_end_port_rotterdam: int = 365 - shock_onset_port_hamburg: int = 0 - shock_end_port_hamburg: int = 365 - shock_onset_fertilizer: int = 0 - shock_end_fertilizer: int = 365 - shock_onset_energy: int = 0 - shock_end_energy: int = 365 - shock_onset_oil_mill: int = 0 - shock_end_oil_mill: int = 365 - shock_onset_feed_mill: int = 0 - shock_end_feed_mill: int = 365 + # --- Supply chain routing --- + santos_share: float = 0.7 # 70% of BRA exports via Santos # --- Agent population size --- n_bra_farmers: int = 10 # South American Farmers @@ -133,6 +94,5 @@ class SupplyChainScenario(Scenario): n_arg_farmers: int = 5 fixed_costs_arg_farmer: float = 42000.0 margin_arg_farmer: float = 0.10 - farm_capacity_arg: float = 1.0 # ARG output multiplier (1.0 = unshocked) diff --git a/src/provider_simenv/shock_registry.py b/src/provider_simenv/shock_registry.py new file mode 100644 index 0000000..160b5fd --- /dev/null +++ b/src/provider_simenv/shock_registry.py @@ -0,0 +1,19 @@ +""" +Single source of truth for the shock-parameter registry. + +To expose a new (entity, impact) to the model: add one row to BINDING. +""" +from __future__ import annotations + + + +# The (entity, impact_field) key, whose supply degradation defines drought severity. +DROUGHT_KEY: tuple[str, str] = ("brazil_farms", "supply") + +def aggregate(impact_field: str, values: list[float]) -> float: + """ + The single aggregation rule, a function of the impact fields: + supply -> min() (worst-case capacity degradation) + price -> max() (worst-case price spike) + """ + return min(values) if impact_field == "supply" else max(values) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 8420d2f..fee5bf5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,6 @@ # event_tracker and pdl_loader use bare imports # -> package directory itself must be importable -SRC = Path(__file__).resolve().parent.parent / "src" / "provider_simenv" +SRC = Path(__file__).resolve().parent.parent / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) \ No newline at end of file diff --git a/tests/test_event_tracker.py b/tests/test_event_tracker.py index 9b1c802..5b5c224 100644 --- a/tests/test_event_tracker.py +++ b/tests/test_event_tracker.py @@ -4,14 +4,17 @@ EventTracker must be stepped once per simulation day starting at 0; an event's activated_at is fixed on the first step where it is eligible, so these tests always step sequentially from day 0. + +Shocks are keyed by the PDL (entity, impact_field) pair; events with no +supply/price impact (impacts={}) are still tracked for use as conditions. """ -from event_tracker import EventTracker +from provider_simenv.event_tracker import EventTracker -def ev(eid, *, param=None, value=None, duration=0, condition="", impact_field="supply"): +def ev(eid, *, entity=None, impacts=None, duration=0, condition=""): return { - "id": eid, "param": param, "value": value, - "duration": duration, "condition": condition, "impact_field": impact_field, + "id": eid, "entity": entity, "impacts": impacts or {}, + "duration": duration, "condition": condition, } @@ -28,27 +31,27 @@ def step_through(tracker, last_day): # --- eligibility / unconditional activation --- def test_unconditional_event_activates_on_its_cascade_day(): - t = EventTracker([ev("a", param="farm_capacity_bra", value=0.6)], [tl(3, "a")]) + t = EventTracker([ev("a", entity="brazil_farms", impacts={"supply": 0.6})], [tl(3, "a")]) step_through(t, 2) assert "a" not in t.get_active_event_ids() t.step(3) assert "a" in t.get_active_event_ids() - assert t.get_param_value("farm_capacity_bra") == 0.6 - assert t.get_shock_scale("farm_capacity_bra") == 1.0 + assert t.get_param_value("brazil_farms", "supply") == 0.6 + assert t.get_shock_scale("brazil_farms", "supply") == 1.0 # --- conditional gating --- def test_conditional_activates_when_dependency_already_active(): events = [ - ev("drought", param="farm_capacity_bra", value=0.6), - ev("export", param="port_capacity_santos", value=0.88, condition="drought.active"), + ev("drought", entity="brazil_farms", impacts={"supply": 0.6}), + ev("export", entity="santos_port", impacts={"supply": 0.88}, condition="drought.active"), ] t = EventTracker(events, [tl(0, "drought"), tl(5, "export")]) step_through(t, 5) assert "drought" in t.get_active_event_ids() assert "export" in t.get_active_event_ids() # drought active since day 0 - assert t.get_param_value("port_capacity_santos") == 0.88 + assert t.get_param_value("santos_port", "supply") == 0.88 def test_same_day_dependency_shifts_conditional_by_one_day(): @@ -56,8 +59,8 @@ def test_same_day_dependency_shifts_conditional_by_one_day(): the SAME cascade day fires one day later, because `.active` is false on the dependency's activation day.""" events = [ - ev("a", param="farm_capacity_bra", value=0.6), - ev("b", param="port_capacity_santos", value=0.8, condition="a.active"), + ev("a", entity="brazil_farms", impacts={"supply": 0.6}), + ev("b", entity="santos_port", impacts={"supply": 0.8}, condition="a.active"), ] t = EventTracker(events, [tl(0, "a"), tl(0, "b")]) t.step(0) @@ -72,8 +75,8 @@ def test_same_day_dependency_shifts_conditional_by_one_day(): def test_duration_threshold_condition(): events = [ - ev("a", param="farm_capacity_bra", value=0.5), - ev("b", param="port_capacity_santos", value=0.5, condition="a.duration > 2d"), + ev("a", entity="brazil_farms", impacts={"supply": 0.5}), + ev("b", entity="santos_port", impacts={"supply": 0.5}, condition="a.duration > 2d"), ] t = EventTracker(events, [tl(0, "a"), tl(0, "b")]) step_through(t, 2) @@ -84,8 +87,8 @@ def test_duration_threshold_condition(): def test_and_condition_requires_both(): events = [ - ev("a"), ev("c"), # unmapped (param=None) — still usable as conditions - ev("b", param="farm_capacity_bra", value=0.5, + ev("a"), ev("c"), # no impact (entity=None) — still usable as conditions + ev("b", entity="brazil_farms", impacts={"supply": 0.5}, condition="a.active AND c.active"), ] t = EventTracker(events, [tl(0, "a"), tl(5, "c"), tl(0, "b")]) @@ -99,7 +102,7 @@ def test_and_condition_requires_both(): def test_or_condition_requires_either(): events = [ ev("a"), - ev("b", param="farm_capacity_bra", value=0.5, condition="a.active OR z.active"), + ev("b", entity="brazil_farms", impacts={"supply": 0.5}, condition="a.active OR z.active"), ] t = EventTracker(events, [tl(0, "a"), tl(0, "b")]) t.step(0) @@ -111,22 +114,22 @@ def test_or_condition_requires_either(): # --- duration / expiry --- def test_event_expires_after_duration(): - t = EventTracker([ev("a", param="farm_capacity_bra", value=0.5, duration=5)], [tl(0, "a")]) + t = EventTracker([ev("a", entity="brazil_farms", impacts={"supply": 0.5}, duration=5)], [tl(0, "a")]) step_through(t, 4) assert "a" in t.get_active_event_ids() t.step(5) # day >= activated_at + duration assert "a" not in t.get_active_event_ids() - assert t.get_param_value("farm_capacity_bra") == 1.0 + assert t.get_param_value("brazil_farms", "supply") == 1.0 def test_zero_duration_is_permanent(): - t = EventTracker([ev("a", param="farm_capacity_bra", value=0.5, duration=0)], [tl(0, "a")]) + t = EventTracker([ev("a", entity="brazil_farms", impacts={"supply": 0.5}, duration=0)], [tl(0, "a")]) step_through(t, 500) assert "a" in t.get_active_event_ids() def test_expired_event_does_not_reactivate(): - t = EventTracker([ev("a", param="farm_capacity_bra", value=0.5, duration=3)], [tl(0, "a")]) + t = EventTracker([ev("a", entity="brazil_farms", impacts={"supply": 0.5}, duration=3)], [tl(0, "a")]) for d in range(20): t.step(d) if d >= 3: @@ -135,45 +138,45 @@ def test_expired_event_does_not_reactivate(): # --- aggregation --- -def test_capacity_params_aggregate_with_min(): +def test_supply_aggregates_with_min(): events = [ - ev("a", param="port_capacity_santos", value=0.8), - ev("b", param="port_capacity_santos", value=0.6), + ev("a", entity="santos_port", impacts={"supply": 0.8}), + ev("b", entity="santos_port", impacts={"supply": 0.6}), ] t = EventTracker(events, [tl(0, "a"), tl(0, "b")]) t.step(0) - assert t.get_param_value("port_capacity_santos") == 0.6 - assert t.get_shock_scale("port_capacity_santos") == 1.0 + assert t.get_param_value("santos_port", "supply") == 0.6 + assert t.get_shock_scale("santos_port", "supply") == 1.0 -def test_price_params_aggregate_with_max(): +def test_price_aggregates_with_max(): events = [ - ev("a", param="energy_price_factor", value=1.5), - ev("b", param="energy_price_factor", value=3.0), + ev("a", entity="gas_supply", impacts={"price": 1.5}), + ev("b", entity="gas_supply", impacts={"price": 3.0}), ] t = EventTracker(events, [tl(0, "a"), tl(0, "b")]) t.step(0) - assert t.get_param_value("energy_price_factor") == 3.0 + assert t.get_param_value("gas_supply", "price") == 3.0 -# --- defaults / unmapped events --- +# --- defaults / impact-less events --- -def test_unknown_param_defaults_to_baseline(): +def test_unknown_key_defaults_to_baseline(): t = EventTracker([], []) t.step(0) - assert t.get_param_value("anything") == 1.0 - assert t.get_shock_scale("anything") == 0.0 + assert t.get_param_value("anything", "supply") == 1.0 + assert t.get_shock_scale("anything", "supply") == 0.0 -def test_unmapped_event_tracked_for_conditions_only(): +def test_impactless_event_tracked_for_conditions_only(): events = [ - ev("trigger"), # param=None -> produces no shock value - ev("dependent", param="farm_capacity_bra", value=0.7, condition="trigger.active"), + ev("trigger"), # impacts={} -> produces no shock value + ev("dependent", entity="brazil_farms", impacts={"supply": 0.7}, condition="trigger.active"), ] t = EventTracker(events, [tl(0, "trigger"), tl(0, "dependent")]) t.step(0) assert "trigger" in t.get_active_event_ids() - assert t.get_param_value("farm_capacity_bra") == 1.0 # trigger has no param + assert t.get_param_value("brazil_farms", "supply") == 1.0 # trigger has no impact t.step(1) assert "dependent" in t.get_active_event_ids() # +1-day shift, then active - assert t.get_param_value("farm_capacity_bra") == 0.7 \ No newline at end of file + assert t.get_param_value("brazil_farms", "supply") == 0.7 \ No newline at end of file diff --git a/tests/test_pdl_event_registry.py b/tests/test_pdl_event_registry.py index d2f6cf5..d20552b 100644 --- a/tests/test_pdl_event_registry.py +++ b/tests/test_pdl_event_registry.py @@ -1,9 +1,13 @@ """ Validates PDLLoader.to_event_registry() against the real s1-soja.pdl.yaml. + +Events carry the PDL `target` entity and an `impacts` dict keyed by impact +field — no translation to model param names. `demand` impacts are +parsed-but-skipped (nothing consumes them yet). """ from pathlib import Path -from pdl_loader import PDLLoader +from provider_simenv.pdl_loader import PDLLoader PDL_PATH = ( Path(__file__).resolve().parent.parent @@ -15,13 +19,28 @@ def test_event_registry_counts(): reg = PDLLoader(PDL_PATH).to_event_registry() # default cascade = soy_crisis_cascade events = reg["events"] assert len(events) == 18 - assert sum(1 for e in events if e["param"] is not None) == 8 # mapped - assert sum(1 for e in events if e["condition"]) == 15 # conditional + assert sum(1 for e in events if e["impacts"]) == 16 # carry a supply/price impact + assert sum(1 for e in events if e["condition"]) == 15 # conditional assert len(reg["timeline"]) == 13 -def test_argentina_supply_increase_is_mapped(): +def test_argentina_supply_increase_carries_both_impacts(): reg = PDLLoader(PDL_PATH).to_event_registry() arg = next(e for e in reg["events"] if e["id"] == "argentina_supply_increase") - assert arg["param"] == "farm_capacity_arg" # NOT unmapped (HTML brief is wrong) - assert arg["value"] == 1.1 # +10% -> 1.10 \ No newline at end of file + assert arg["entity"] == "argentina_farms" + # supply +10% AND price +15% — both carried, no collision + assert arg["impacts"] == {"supply": 1.1, "price": 1.15} + + +def test_demand_only_event_has_empty_impacts(): + reg = PDLLoader(PDL_PATH).to_event_registry() + sub = next(e for e in reg["events"] if e["id"] == "consumer_substitution") + assert sub["entity"] == "food_retail" + assert sub["impacts"] == {} # demand -8% is parsed-but-skipped + + +def test_demand_and_price_event_keeps_only_price(): + reg = PDLLoader(PDL_PATH).to_event_registry() + fert = next(e for e in reg["events"] if e["id"] == "fertilizer_demand_spike") + assert fert["entity"] == "fertilizer_supply" + assert fert["impacts"] == {"price": 1.8} # demand +40% skipped, price +80% kept \ No newline at end of file diff --git a/tests/test_shock_registry.py b/tests/test_shock_registry.py new file mode 100644 index 0000000..535a080 --- /dev/null +++ b/tests/test_shock_registry.py @@ -0,0 +1,126 @@ +""" +acceptance: shocks are keyed by the PDL (entity, impact_field) pair, passed +through verbatim — there is no translation layer. + +Proves: + 1. shock_registry exposes only the aggregation rule + the drought anchor + (no BINDING / param_names / drought_param translation table). + 2. A NEW shock can be driven purely from a PDL file — zero Python edits — + because the (entity, field) key comes straight from the PDL. Uses + rotterdam_port, absent from the shipped s1-soja.pdl.yaml. + 3. Aggregation is a function of the impact field (supply -> min, price -> max). + 4. One entity carrying BOTH supply and price impacts resolves each + independently — no collision (the case an entity-only key would fail). +""" +import textwrap +from pathlib import Path + +from provider_simenv import shock_registry as reg +from provider_simenv.pdl_loader import PDLLoader +from provider_simenv.event_tracker import EventTracker + + +# --- 1. registry is only rule + anchor, no translation table --- + +def test_no_translation_table(): + assert not hasattr(reg, "BINDING") + assert not hasattr(reg, "param_for") + assert not hasattr(reg, "param_names") + assert not hasattr(reg, "drought_param") + + +def test_drought_key_is_a_pdl_entity_pair(): + assert reg.DROUGHT_KEY == ("brazil_farms", "supply") + + +def test_aggregate_rule_by_field(): + assert reg.aggregate("supply", [0.8, 0.5]) == 0.5 # worst-case supply + assert reg.aggregate("price", [1.2, 1.5]) == 1.5 # worst-case price + + +# --- 2. a new shock, driven entirely from PDL (zero Python edits) --- + +PDL_NEW_SHOCK = textwrap.dedent(""" + scenario: + name: rotterdam_strike_test + events: + - id: rotterdam_strike + trigger: + target: rotterdam_port + impact: + supply: "-25%" + duration: 30d + cascades: + - id: test_cascade + timeline: + - {at: 5d, event: rotterdam_strike} +""") + + +def _tracker_from_pdl(pdl_text: str, tmp_path: Path) -> EventTracker: + path = tmp_path / "new_shock.pdl.yaml" + path.write_text(pdl_text, encoding="utf-8") + registry = PDLLoader(path).to_event_registry() + return EventTracker(events=registry["events"], timeline=registry["timeline"]) + + +def test_new_pdl_shock_with_zero_code(tmp_path): + t = _tracker_from_pdl(PDL_NEW_SHOCK, tmp_path) + key = ("rotterdam_port", "supply") + + # before onset (day 5): unshocked + for d in range(5): + t.step(d) + assert t.get_shock_scale(*key) == 0.0 + assert t.get_param_value(*key) == 1.0 + + # at onset: -25% supply -> 0.75 + t.step(5) + assert t.get_shock_scale(*key) == 1.0 + assert t.get_param_value(*key) == 0.75 + + # after duration (5 + 30): expired, back to unshocked + t.step(35) + assert t.get_shock_scale(*key) == 0.0 + assert t.get_param_value(*key) == 1.0 + + +# --- 3. aggregation derived from impact field --- + +def _ev(eid, *, entity, impacts): + return {"id": eid, "entity": entity, "impacts": impacts, + "duration": 0, "condition": ""} + + +def test_supply_aggregates_with_min(): + t = EventTracker( + [_ev("a", entity="santos_port", impacts={"supply": 0.8}), + _ev("b", entity="santos_port", impacts={"supply": 0.5})], + [{"at_day": 0, "event_id": "a"}, {"at_day": 0, "event_id": "b"}], + ) + t.step(0) + assert t.get_param_value("santos_port", "supply") == 0.5 # worst-case supply + + +def test_price_aggregates_with_max(): + t = EventTracker( + [_ev("a", entity="gas_supply", impacts={"price": 1.2}), + _ev("b", entity="gas_supply", impacts={"price": 1.5})], + [{"at_day": 0, "event_id": "a"}, {"at_day": 0, "event_id": "b"}], + ) + t.step(0) + assert t.get_param_value("gas_supply", "price") == 1.5 # worst-case price + + +# --- 4. one entity, both impacts, no collision (the B1-failure case) --- + +def test_one_event_both_impacts_resolves_independently(): + # mirrors argentina_supply_increase: supply +10% AND price +15% on one entity. + t = EventTracker( + [_ev("arg_increase", entity="argentina_farms", + impacts={"supply": 1.10, "price": 1.15})], + [{"at_day": 0, "event_id": "arg_increase"}], + ) + t.step(0) + assert t.get_param_value("argentina_farms", "supply") == 1.10 + assert t.get_param_value("argentina_farms", "price") == 1.15 \ No newline at end of file