diff --git a/qlib/contrib/meta/data_selection/dataset.py b/qlib/contrib/meta/data_selection/dataset.py index 61efdd63cfb..97b339fb858 100644 --- a/qlib/contrib/meta/data_selection/dataset.py +++ b/qlib/contrib/meta/data_selection/dataset.py @@ -26,7 +26,7 @@ def __init__(self, task_tpl: dict, step: int, exp_name: str): self.step = step self.exp_name = exp_name - def setup(self, trainer=TrainerR, trainer_kwargs={}): + def setup(self, trainer=TrainerR, trainer_kwargs=None): """ after running this function `self.data_ic_df` will become set. Each col represents a data. @@ -47,6 +47,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): """ + if trainer_kwargs is None: + trainer_kwargs = {} + # 1) prepare the prediction of proxy models perf_task_tpl = deepcopy(self.task_tpl) # this task is supposed to contains no complicated objects # The only thing we want to save is the prediction diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 2f3615a6357..ca739b14749 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -360,7 +360,7 @@ class DataLoaderDH(DataLoader): - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook. """ - def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): + def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False): """ Parameters ---------- @@ -386,6 +386,9 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False """ from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415 + if fetch_kwargs is None: + fetch_kwargs = {} + if is_group: self.handlers = { grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items() diff --git a/qlib/tests/config.py b/qlib/tests/config.py index ea1b2365945..b659515ca6e 100644 --- a/qlib/tests/config.py +++ b/qlib/tests/config.py @@ -91,14 +91,22 @@ def get_dataset_config( } -def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): +def get_gbdt_task(dataset_kwargs=None, handler_kwargs=None): + if dataset_kwargs is None: + dataset_kwargs = {} + if handler_kwargs is None: + handler_kwargs = {"instruments": CSI300_MARKET} return { "model": GBDT_MODEL, "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), } -def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): +def get_record_lgb_config(dataset_kwargs=None, handler_kwargs=None): + if dataset_kwargs is None: + dataset_kwargs = {} + if handler_kwargs is None: + handler_kwargs = {"instruments": CSI300_MARKET} return { "model": { "class": "LGBModel", @@ -109,7 +117,11 @@ def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI3 } -def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): +def get_record_xgboost_config(dataset_kwargs=None, handler_kwargs=None): + if dataset_kwargs is None: + dataset_kwargs = {} + if handler_kwargs is None: + handler_kwargs = {"instruments": CSI300_MARKET} return { "model": { "class": "XGBModel", diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 2a94ebd555b..35a1fd3655c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -886,7 +886,9 @@ def register_wrapper(wrapper, cls_or_obj, module_path=None): wrapper.register(obj) -def load_dataset(path_or_obj, index_col=[0, 1]): +def load_dataset(path_or_obj, index_col=None): + if index_col is None: + index_col = [0, 1] """load dataset from multiple file formats""" if isinstance(path_or_obj, pd.DataFrame): return path_or_obj diff --git a/qlib/utils/mod.py b/qlib/utils/mod.py index 5cb2ed3f453..f7a03be8105 100644 --- a/qlib/utils/mod.py +++ b/qlib/utils/mod.py @@ -123,9 +123,11 @@ def init_instance_by_config( config: InstConf, default_module=None, accept_types: Union[type, Tuple[type]] = (), - try_kwargs: Dict = {}, + try_kwargs: Dict = None, **kwargs, ) -> Any: + if try_kwargs is None: + try_kwargs = {} """ get initialized instance with config diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 99aedfcd50c..e0dc6e5087c 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -104,8 +104,10 @@ def resam_ts_data( start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, method: Union[str, Callable] = "last", - method_kwargs: dict = {}, + method_kwargs: dict = None, ): + if method_kwargs is None: + method_kwargs = {} """ Resample value from time-series data diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 09e96d444f2..193cd9d7582 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -153,7 +153,7 @@ def _postpone_action(self): """ return self.status == self.STATUS_SIMULATING and self.trainer.is_delay() - def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}): + def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = None): """ Get tasks from every strategy's first_tasks method and train them. If using DelayTrainer, it can finish training all together after every strategy's first_tasks. @@ -164,6 +164,8 @@ def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dic """ if strategies is None: strategies = self.strategies + if model_kwargs is None: + model_kwargs = {} models_list = [] for strategy in strategies: @@ -346,7 +348,7 @@ def simulate( self.status = self.STATUS_ONLINE return self.get_signals() - def delay_prepare(self, model_kwargs={}, signal_kwargs={}): + def delay_prepare(self, model_kwargs=None, signal_kwargs=None): """ Prepare all models and signals if something is waiting for preparation. @@ -354,6 +356,10 @@ def delay_prepare(self, model_kwargs={}, signal_kwargs={}): model_kwargs: the params for `end_train` signal_kwargs: the params for `prepare_signals` """ + if model_kwargs is None: + model_kwargs = {} + if signal_kwargs is None: + signal_kwargs = {} # FIXME: # This method is not implemented in the proper way!!! last_models = {} diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index d545e4bc9a6..d2a85788b62 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -120,7 +120,7 @@ def __init__( self.tool = OnlineToolR(self.exp_name) self.ta = TimeAdjuster() - def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): + def get_collector(self, process_list=None, rec_key_func=None, rec_filter_func=None, artifacts_key=None): """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models. @@ -132,6 +132,8 @@ def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_fi rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts. """ + if process_list is None: + process_list = [RollingGroup()] def rec_key(recorder): task_config = recorder.load_object("task") diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index cf95e600633..6cc9daf0f47 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -302,7 +302,7 @@ def generate(self, task: dict) -> List[dict]: class MultiHorizonGenBase(TaskGen): - def __init__(self, horizon: List[int] = [5], label_leak_n=2): + def __init__(self, horizon: List[int] = None, label_leak_n=2): """ This task generator tries to generate tasks for different horizons based on an existing task @@ -317,6 +317,8 @@ def __init__(self, horizon: List[int] = [5], label_leak_n=2): - The label is the return of buying stock on `T + 1` and selling it on `T + 2` - the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample) """ + if horizon is None: + horizon = [5] self.horizon = list(horizon) self.label_leak_n = label_leak_n self.ta = TimeAdjuster()