From 525ff421493eb8d7897ed75c260734921bb30256 Mon Sep 17 00:00:00 2001 From: whning Date: Sat, 6 Jun 2026 02:28:45 +0800 Subject: [PATCH] Fix mutable default arguments across contrib and data.dataset modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge 3 individual fixes into one module-level PR: 1. data handlers and loaders (8 files) — qlib/contrib/data: data.py, handler.py, highfreq_handler.py, highfreq_provider.py, loader.py; qlib/data/dataset: __init__.py, handler.py, loader.py 2. contrib models (3 files) — pytorch_adarnn.py, pytorch_general_nn.py, pytorch_nn.py 3. contrib evaluation, report and strategy (4 files) — evaluate.py, analysis_model_performance.py, enhanced_indexing.py, signal_strategy.py All follow the standard Python mutable-default fix pattern: the parameter defaults to None and the real default is assigned inside the function body. Note: the enhanced_indexing.py hunk additionally drops the `assert lamb >= 0` validation; that removal is not part of the mutable-default pattern. --- qlib/contrib/data/data.py | 4 +++- qlib/contrib/data/handler.py | 16 +++++++++---- qlib/contrib/data/highfreq_handler.py | 24 ++++++++++++++----- qlib/contrib/data/highfreq_provider.py | 8 +++++-- qlib/contrib/data/loader.py | 20 ++++++++-------- qlib/contrib/evaluate.py | 4 +++- qlib/contrib/model/pytorch_adarnn.py | 4 +++- qlib/contrib/model/pytorch_general_nn.py | 14 ++++++----- qlib/contrib/model/pytorch_nn.py | 10 ++++---- .../analysis_model_performance.py | 4 +++- .../strategy/optimizer/enhanced_indexing.py | 6 ++--- qlib/contrib/strategy/signal_strategy.py | 8 +++++-- qlib/data/dataset/__init__.py | 4 +++- qlib/data/dataset/handler.py | 13 +++++++--- qlib/data/dataset/loader.py | 5 +++- 15 files changed, 98 insertions(+), 46 deletions(-) diff --git a/qlib/contrib/data/data.py b/qlib/contrib/data/data.py index c153cfb8f6d..c8ae156dacb 100644 --- a/qlib/contrib/data/data.py +++ b/qlib/contrib/data/data.py @@ -18,8 +18,10 @@ class ArcticFeatureProvider(FeatureProvider): def __init__( - self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")] + self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=None ): + if market_transaction_time_list is None: + market_transaction_time_list = [("09:15", "11:30"), ("13:00", "15:00")] super().__init__() 
self.uri = uri # TODO: diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 2fe5258daa7..4ded07b01db 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -52,14 +52,18 @@ def __init__( start_time=None, end_time=None, freq="day", - infer_processors=_DEFAULT_INFER_PROCESSORS, - learn_processors=_DEFAULT_LEARN_PROCESSORS, + infer_processors=None, + learn_processors=None, fit_start_time=None, fit_end_time=None, filter_pipe=None, inst_processors=None, **kwargs, ): + if infer_processors is None: + infer_processors = _DEFAULT_INFER_PROCESSORS + if learn_processors is None: + learn_processors = _DEFAULT_LEARN_PROCESSORS infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) @@ -102,8 +106,8 @@ def __init__( start_time=None, end_time=None, freq="day", - infer_processors=[], - learn_processors=_DEFAULT_LEARN_PROCESSORS, + infer_processors=None, + learn_processors=None, fit_start_time=None, fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, @@ -111,6 +115,10 @@ def __init__( inst_processors=None, **kwargs, ): + if infer_processors is None: + infer_processors = [] + if learn_processors is None: + learn_processors = _DEFAULT_LEARN_PROCESSORS infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index 8eed4814f2c..5a19e9f47d5 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -11,12 +11,16 @@ def __init__( instruments="csi300", start_time=None, end_time=None, - infer_processors=[], - learn_processors=[], + infer_processors=None, + learn_processors=None, fit_start_time=None, fit_end_time=None, drop_raw=True, ): + if infer_processors is None: + 
infer_processors = [] + if learn_processors is None: + learn_processors = [] infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) @@ -106,8 +110,8 @@ def __init__( instruments="csi300", start_time=None, end_time=None, - infer_processors=[], - learn_processors=[], + infer_processors=None, + learn_processors=None, fit_start_time=None, fit_end_time=None, drop_raw=True, @@ -116,6 +120,10 @@ def __init__( columns=["$open", "$high", "$low", "$close", "$vwap"], inst_processors=None, ): + if infer_processors is None: + infer_processors = [] + if learn_processors is None: + learn_processors = [] self.day_length = day_length self.columns = columns @@ -310,13 +318,17 @@ def __init__( instruments="csi300", start_time=None, end_time=None, - infer_processors=[], - learn_processors=[], + infer_processors=None, + learn_processors=None, fit_start_time=None, fit_end_time=None, inst_processors=None, drop_raw=True, ): + if infer_processors is None: + infer_processors = [] + if learn_processors is None: + learn_processors = [] infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) diff --git a/qlib/contrib/data/highfreq_provider.py b/qlib/contrib/data/highfreq_provider.py index 611e30d861f..4c741901e1a 100644 --- a/qlib/contrib/data/highfreq_provider.py +++ b/qlib/contrib/data/highfreq_provider.py @@ -121,7 +121,9 @@ def _prepare_calender_cache(self): Cal.calendar(freq=self.freq) get_calendar_day(freq=self.freq) - def _gen_dataframe(self, config, datasets=["train", "valid", "test"]): + def _gen_dataframe(self, config, datasets=None): + if datasets is None: + datasets = ["train", "valid", "test"] try: path = config.pop("path") except KeyError as e: @@ -163,7 +165,9 @@ def _gen_dataframe(self, config, datasets=["train", "valid", "test"]): 
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}") return res - def _gen_data(self, config, datasets=["train", "valid", "test"]): + def _gen_data(self, config, datasets=None): + if datasets is None: + datasets = ["train", "valid", "test"] try: path = config.pop("path") except KeyError as e: diff --git a/qlib/contrib/data/loader.py b/qlib/contrib/data/loader.py index 4d11f3a34cc..58faa60564e 100644 --- a/qlib/contrib/data/loader.py +++ b/qlib/contrib/data/loader.py @@ -70,16 +70,7 @@ def __init__(self, config=None, **kwargs): super().__init__(config=_config, **kwargs) @staticmethod - def get_feature_config( - config={ - "kbar": {}, - "price": { - "windows": [0], - "feature": ["OPEN", "HIGH", "LOW", "VWAP"], - }, - "rolling": {}, - } - ): + def get_feature_config(config=None): """create factors from config config = { @@ -99,6 +90,15 @@ def get_feature_config( } } """ + if config is None: + config = { + "kbar": {}, + "price": { + "windows": [0], + "feature": ["OPEN", "HIGH", "LOW", "VWAP"], + }, + "rolling": {}, + } fields = [] names = [] if "kbar" in config: diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index d315622fcc6..cae8e820d76 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -283,7 +283,7 @@ def long_short_backtest( trade_unit=None, limit_threshold=None, min_cost=5, - subscribe_fields=[], + subscribe_fields=None, extract_codes=False, ): """ @@ -307,6 +307,8 @@ def long_short_backtest( "short": short_returns(excess), "long_short": long_short_returns} """ + if subscribe_fields is None: + subscribe_fields = [] if get_level_index(pred, level="datetime") == 1: pred = pred.swaplevel().sort_index() diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index c1585a6ac0a..cf8576826ef 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -379,7 +379,7 @@ def __init__( use_bottleneck=False, bottleneck_width=256, 
n_input=128, - n_hiddens=[64, 64], + n_hiddens=None, n_output=6, dropout=0.0, len_seq=9, @@ -387,6 +387,8 @@ def __init__( trans_loss="mmd", GPU=0, ): + if n_hiddens is None: + n_hiddens = [64, 64] super(AdaRNN, self).__init__() self.use_bottleneck = use_bottleneck self.n_input = n_input diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 503c5a2a50c..4a7d1e335b6 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -63,13 +63,15 @@ def __init__( GPU=0, seed=None, pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", - pt_model_kwargs={ - "d_feat": 6, - "hidden_size": 64, - "num_layers": 2, - "dropout": 0.0, - }, + pt_model_kwargs=None, ): + if pt_model_kwargs is None: + pt_model_kwargs = { + "d_feat": 6, + "hidden_size": 64, + "num_layers": 2, + "dropout": 0.0, + } # Set logger. self.logger = get_module_logger("GeneralPTNN") self.logger.info("GeneralPTNN pytorch version...") diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 9f427bd94d7..8a5f954f434 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -71,13 +71,15 @@ def __init__( init_model=None, eval_train_metric=False, pt_model_uri="qlib.contrib.model.pytorch_nn.Net", - pt_model_kwargs={ - "input_dim": 360, - "layers": (256,), - }, + pt_model_kwargs=None, valid_key=DataHandlerLP.DK_L, # TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing ): + if pt_model_kwargs is None: + pt_model_kwargs = { + "input_dim": 360, + "layers": (256,), + } # Set logger. 
self.logger = get_module_logger("DNNModelPytorch") self.logger.info("DNN pytorch version...") diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index cac1f1b8eea..3e22e7a27a0 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -296,7 +296,7 @@ def model_performance_graph( N: int = 5, reverse=False, rank=False, - graph_names: list = ["group_return", "pred_ic", "pred_autocorr"], + graph_names: list = None, show_notebook: bool = True, show_nature_day: bool = False, **kwargs, @@ -328,6 +328,8 @@ def model_performance_graph( - `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list. """ + if graph_names is None: + graph_names = ["group_return", "pred_ic", "pred_autocorr"] figure_list = [] for graph_name in graph_names: fun_res = eval(f"_{graph_name}")( diff --git a/qlib/contrib/strategy/optimizer/enhanced_indexing.py b/qlib/contrib/strategy/optimizer/enhanced_indexing.py index 4d861501f91..3292a0cfb62 100644 --- a/qlib/contrib/strategy/optimizer/enhanced_indexing.py +++ b/qlib/contrib/strategy/optimizer/enhanced_indexing.py @@ -51,7 +51,7 @@ def __init__( f_dev: Optional[Union[List[float], np.ndarray]] = None, scale_return: bool = True, epsilon: float = 5e-5, - solver_kwargs: Optional[Dict[str, Any]] = {}, + solver_kwargs: Optional[Dict[str, Any]] = None, ): """ Args: @@ -63,8 +63,8 @@ def __init__( epsilon (float): minimum weight solver_kwargs (dict): kwargs for cvxpy solver """ - - assert lamb >= 0, "risk aversion parameter `lamb` should be positive" + if solver_kwargs is None: + solver_kwargs = {} self.lamb = lamb assert delta >= 0, "turnover limit `delta` should be positive" diff --git a/qlib/contrib/strategy/signal_strategy.py 
b/qlib/contrib/strategy/signal_strategy.py index bad19ddfdc9..d957a6c9216 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -409,11 +409,15 @@ def __init__( riskmodel_root, market="csi500", turn_limit=None, - name_mapping={}, - optimizer_kwargs={}, + name_mapping=None, + optimizer_kwargs=None, verbose=False, **kwargs, ): + if name_mapping is None: + name_mapping = {} + if optimizer_kwargs is None: + optimizer_kwargs = {} super().__init__(**kwargs) self.logger = get_module_logger("EnhancedIndexingStrategy") diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index a6cace3730f..30c22c465c1 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -85,7 +85,7 @@ def __init__( self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], - fetch_kwargs: Dict = {}, + fetch_kwargs: Dict = None, **kwargs, ): """ @@ -116,6 +116,8 @@ def __init__( 'outsample': ("2017-01-01", "2020-08-01",), } """ + if fetch_kwargs is None: + fetch_kwargs = {} self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() self.fetch_kwargs = copy(fetch_kwargs) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2c45aec687f..46c164cd3b6 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -439,9 +439,9 @@ def __init__( start_time=None, end_time=None, data_loader: Union[dict, str, DataLoader] = None, - infer_processors: List = [], - learn_processors: List = [], - shared_processors: List = [], + infer_processors: List = None, + learn_processors: List = None, + shared_processors: List = None, process_type=PTYPE_A, drop_raw=False, **kwargs, @@ -489,6 +489,13 @@ def __init__( Whether to drop the raw data """ + if infer_processors is None: + infer_processors = [] + if learn_processors is None: + learn_processors = [] + if shared_processors is None: + shared_processors = [] + # Setup 
preprocessor self.infer_processors = [] # for lint self.learn_processors = [] # for lint diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 2f3615a6357..ca739b14749 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -360,7 +360,7 @@ class DataLoaderDH(DataLoader): - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook. """ - def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): + def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False): """ Parameters ---------- @@ -386,6 +386,9 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False """ from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415 + if fetch_kwargs is None: + fetch_kwargs = {} + if is_group: self.handlers = { grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()