Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_exchange(
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
subscribe_fields: list = [],
subscribe_fields: list = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
Expand Down Expand Up @@ -87,6 +87,8 @@ def get_exchange(
an initialized Exchange object
"""

if subscribe_fields is None:
subscribe_fields = []
if limit_threshold is None:
limit_threshold = C.limit_threshold
if exchange is None:
Expand Down Expand Up @@ -181,7 +183,7 @@ def get_strategy_executor(
executor: Union[str, dict, object, Path],
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
# NOTE:
Expand All @@ -190,6 +192,9 @@ def get_strategy_executor(
from ..strategy.base import BaseStrategy # pylint: disable=C0415
from .executor import BaseExecutor # pylint: disable=C0415

if exchange_kwargs is None:
exchange_kwargs = {}

trade_account = create_account_instance(
start_time=start_time,
end_time=end_time,
Expand Down Expand Up @@ -221,7 +226,7 @@ def backtest(
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
Expand Down Expand Up @@ -263,6 +268,8 @@ def backtest(
It is organized in a dict format

"""
if exchange_kwargs is None:
exchange_kwargs = {}
trade_strategy, trade_executor = get_strategy_executor(
start_time,
end_time,
Expand All @@ -283,7 +290,7 @@ def collect_data(
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
return_value: dict | None = None,
) -> Generator[object, None, None]:
Expand All @@ -296,6 +303,8 @@ def collect_data(
object
trade decision
"""
if exchange_kwargs is None:
exchange_kwargs = {}
trade_strategy, trade_executor = get_strategy_executor(
start_time,
end_time,
Expand Down
34 changes: 24 additions & 10 deletions qlib/backtest/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ class Account:
def __init__(
self,
init_cash: float = 1e9,
position_dict: dict = {},
position_dict: dict = None,
freq: str = "day",
benchmark_config: dict = {},
benchmark_config: dict = None,
pos_type: str = "Position",
port_metr_enabled: bool = True,
) -> None:
Expand All @@ -103,6 +103,10 @@ def __init__(
by default {}.
"""

if position_dict is None:
position_dict = {}
if benchmark_config is None:
benchmark_config = {}
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.benchmark_config: dict = {} # avoid no attribute error
Expand Down Expand Up @@ -306,14 +310,24 @@ def update_indicator(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
trade_info: list = None,
inner_order_indicators: List[BaseOrderIndicator] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
indicator_config: dict = None,
) -> None:
"""update trade indicators and order indicators in each bar end"""
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`

if trade_info is None:
trade_info = []
if inner_order_indicators is None:
inner_order_indicators = []
if decision_list is None:
decision_list = []
if indicator_config is None:
indicator_config = {}
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`

# indicator is trading (e.g. high-frequency order execution) related analysis
self.indicator.reset()

Expand Down Expand Up @@ -342,10 +356,10 @@ def update_bar_end(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
trade_info: list = None,
inner_order_indicators: List[BaseOrderIndicator] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
indicator_config: dict = None,
) -> None:
"""update account at each trading bar step

Expand Down
5 changes: 4 additions & 1 deletion qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
subscribe_fields: list = [],
subscribe_fields: list = None,
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict, None] = None,
open_cost: float = 0.0015,
Expand Down Expand Up @@ -141,6 +141,9 @@ def __init__(
if deal_price is None:
deal_price = C.deal_price

if subscribe_fields is None:
subscribe_fields = []

# we have some verbose information here. So logging is enabled
self.logger = get_module_logger("online operator")

Expand Down
10 changes: 7 additions & 3 deletions qlib/backtest/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
indicator_config: dict = None,
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
Expand Down Expand Up @@ -108,6 +108,8 @@ def __init__(
Please refer to the docs of BasePosition.settle_start
"""
self.time_per_step = time_per_step
if indicator_config is None:
indicator_config = {}
self.indicator_config = indicator_config
self.generate_portfolio_metrics = generate_portfolio_metrics
self.verbose = verbose
Expand Down Expand Up @@ -321,7 +323,7 @@ def __init__(
inner_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
indicator_config: dict = None,
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
Expand All @@ -346,6 +348,8 @@ def __init__(
force to align the trade_range decision
It is only for nested executor, because range_limit is given by outer strategy
"""
if indicator_config is None:
indicator_config = {}
self.inner_executor: BaseExecutor = init_instance_by_config(
inner_executor,
common_infra=common_infra,
Expand Down Expand Up @@ -530,7 +534,7 @@ def __init__(
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
indicator_config: dict = None,
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
Expand Down
5 changes: 4 additions & 1 deletion qlib/backtest/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class Position(BasePosition):
}
"""

def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:
def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] | None = None) -> None:
"""Init position by cash and position_dict.

Parameters
Expand All @@ -262,6 +262,9 @@ def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, flo
"""
super().__init__()

if position_dict is None:
position_dict = {}

# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
Expand Down
24 changes: 17 additions & 7 deletions qlib/backtest/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ class PortfolioMetrics:
update report
"""

def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
def __init__(self, freq: str = "day", benchmark_config: dict | None = None) -> None:
"""
Parameters
----------
freq : str
frequency of trading bar, used for updating hold count of trading bar
benchmark_config : dict
benchmark_config : dict, optional
config of benchmark, may including the following arguments:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
Expand Down Expand Up @@ -73,6 +73,8 @@ def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
"""

self.init_vars()
if benchmark_config is None:
benchmark_config = {}
self.init_bench(freq=freq, benchmark_config=benchmark_config)

def init_vars(self) -> None:
Expand Down Expand Up @@ -385,13 +387,15 @@ def _get_base_vol_pri(
direction: OrderDir,
decision: BaseTradeDecision,
trade_exchange: Exchange,
pa_config: dict = {},
pa_config: dict | None = None,
) -> Tuple[Optional[float], Optional[float]]:
"""
Get the base volume and price information
All the base price values are rooted from this function
"""

if pa_config is None:
pa_config = {}
agg = pa_config.get("agg", "twap").lower()
price = pa_config.get("price", "deal_price").lower()

Expand Down Expand Up @@ -457,7 +461,7 @@ def _agg_base_price(
inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
trade_exchange: Exchange,
pa_config: dict = {},
pa_config: dict | None = None,
) -> None:
"""
# NOTE:!!!!
Expand All @@ -472,7 +476,7 @@ def _agg_base_price(
a list of decisions according to inner_order_indicators
trade_exchange : Exchange
for retrieving trading price
pa_config : dict
pa_config : dict, optional
For example
{
"agg": "twap", # "vwap"
Expand All @@ -481,6 +485,8 @@ def _agg_base_price(
}
"""

if pa_config is None:
pa_config = {}
# TODO: I think there are potentials to be optimized
trade_dir = self.order_indicator.get_index_data("trade_dir")
if len(trade_dir) > 0:
Expand Down Expand Up @@ -542,8 +548,10 @@ def agg_order_indicators(
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
outer_trade_decision: BaseTradeDecision,
trade_exchange: Exchange,
indicator_config: dict = {},
indicator_config: dict | None = None,
) -> None:
if indicator_config is None:
indicator_config = {}
self._agg_order_trade_info(inner_order_indicators)
self._update_trade_amount(outer_trade_decision)
self._update_order_fulfill_rate()
Expand Down Expand Up @@ -609,8 +617,10 @@ def cal_trade_indicators(
self,
trade_start_time: Union[str, pd.Timestamp],
freq: str,
indicator_config: dict = {},
indicator_config: dict | None = None,
) -> None:
if indicator_config is None:
indicator_config = {}
show_indicator = indicator_config.get("show_indicator", False)
ffr_config = indicator_config.get("ffr_config", {})
pa_config = indicator_config.get("pa_config", {})
Expand Down