Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion qlib/contrib/meta/data_selection/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, task_tpl: dict, step: int, exp_name: str):
self.step = step
self.exp_name = exp_name

def setup(self, trainer=TrainerR, trainer_kwargs={}):
def setup(self, trainer=TrainerR, trainer_kwargs=None):
"""
After running this function, `self.data_ic_df` will be set.
Each col represents a data.
Expand All @@ -47,6 +47,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}):

"""

if trainer_kwargs is None:
trainer_kwargs = {}

# 1) prepare the prediction of proxy models
perf_task_tpl = deepcopy(self.task_tpl) # this task is supposed to contains no complicated objects
# The only thing we want to save is the prediction
Expand Down
5 changes: 4 additions & 1 deletion qlib/data/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class DataLoaderDH(DataLoader):
- The underlying data handler should be configured, but the data loader doesn't provide such an interface or hook.
"""

def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False):
"""
Parameters
----------
Expand All @@ -386,6 +386,9 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False
"""
from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415

if fetch_kwargs is None:
fetch_kwargs = {}

if is_group:
self.handlers = {
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
Expand Down
18 changes: 15 additions & 3 deletions qlib/tests/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,22 @@ def get_dataset_config(
}


def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
def get_gbdt_task(dataset_kwargs=None, handler_kwargs=None):
    """Build a task config that pairs ``GBDT_MODEL`` with a dataset config.

    Parameters
    ----------
    dataset_kwargs : dict, optional
        Keyword arguments forwarded to ``get_dataset_config``. An empty dict
        is used when ``None``.
    handler_kwargs : dict, optional
        Passed to ``get_dataset_config`` as ``handler_kwargs``. When ``None``
        it becomes ``{"instruments": CSI300_MARKET}``.

    Returns
    -------
    dict
        A dict with the keys ``"model"`` and ``"dataset"``.
    """
    # Defaults are created per call so no dict object is shared between calls.
    dataset_options = {} if dataset_kwargs is None else dataset_kwargs
    handler_options = {"instruments": CSI300_MARKET} if handler_kwargs is None else handler_kwargs
    dataset_config = get_dataset_config(**dataset_options, handler_kwargs=handler_options)
    return {"model": GBDT_MODEL, "dataset": dataset_config}


def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
def get_record_lgb_config(dataset_kwargs=None, handler_kwargs=None):
if dataset_kwargs is None:
dataset_kwargs = {}
if handler_kwargs is None:
handler_kwargs = {"instruments": CSI300_MARKET}
return {
"model": {
"class": "LGBModel",
Expand All @@ -109,7 +117,11 @@ def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI3
}


def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
def get_record_xgboost_config(dataset_kwargs=None, handler_kwargs=None):
if dataset_kwargs is None:
dataset_kwargs = {}
if handler_kwargs is None:
handler_kwargs = {"instruments": CSI300_MARKET}
return {
"model": {
"class": "XGBModel",
Expand Down
4 changes: 3 additions & 1 deletion qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,9 @@ def register_wrapper(wrapper, cls_or_obj, module_path=None):
wrapper.register(obj)


def load_dataset(path_or_obj, index_col=[0, 1]):
def load_dataset(path_or_obj, index_col=None):
if index_col is None:
index_col = [0, 1]
"""load dataset from multiple file formats"""
if isinstance(path_or_obj, pd.DataFrame):
return path_or_obj
Expand Down
4 changes: 3 additions & 1 deletion qlib/utils/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ def init_instance_by_config(
config: InstConf,
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
try_kwargs: Dict = None,
**kwargs,
) -> Any:
if try_kwargs is None:
try_kwargs = {}
"""
get initialized instance with config
Expand Down
4 changes: 3 additions & 1 deletion qlib/utils/resam.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ def resam_ts_data(
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
method: Union[str, Callable] = "last",
method_kwargs: dict = {},
method_kwargs: dict = None,
):
if method_kwargs is None:
method_kwargs = {}
"""
Resample value from time-series data

Expand Down
10 changes: 8 additions & 2 deletions qlib/workflow/online/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _postpone_action(self):
"""
return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()

def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = None):
"""
Get tasks from every strategy's first_tasks method and train them.
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
Expand All @@ -164,6 +164,8 @@ def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dic
"""
if strategies is None:
strategies = self.strategies
if model_kwargs is None:
model_kwargs = {}

models_list = []
for strategy in strategies:
Expand Down Expand Up @@ -346,14 +348,18 @@ def simulate(
self.status = self.STATUS_ONLINE
return self.get_signals()

def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
def delay_prepare(self, model_kwargs=None, signal_kwargs=None):
"""
Prepare all models and signals if something is waiting for preparation.

Args:
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
if model_kwargs is None:
model_kwargs = {}
if signal_kwargs is None:
signal_kwargs = {}
# FIXME:
# This method is not implemented in the proper way!!!
last_models = {}
Expand Down
4 changes: 3 additions & 1 deletion qlib/workflow/online/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()

def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
def get_collector(self, process_list=None, rec_key_func=None, rec_filter_func=None, artifacts_key=None):
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.

Expand All @@ -132,6 +132,8 @@ def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_fi
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
"""
if process_list is None:
process_list = [RollingGroup()]

def rec_key(recorder):
task_config = recorder.load_object("task")
Expand Down
4 changes: 3 additions & 1 deletion qlib/workflow/task/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def generate(self, task: dict) -> List[dict]:


class MultiHorizonGenBase(TaskGen):
def __init__(self, horizon: List[int] = [5], label_leak_n=2):
def __init__(self, horizon: List[int] = None, label_leak_n=2):
"""
This task generator tries to generate tasks for different horizons based on an existing task

Expand All @@ -317,6 +317,8 @@ def __init__(self, horizon: List[int] = [5], label_leak_n=2):
- The label is the return of buying stock on `T + 1` and selling it on `T + 2`
- the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample)
"""
if horizon is None:
horizon = [5]
self.horizon = list(horizon)
self.label_leak_n = label_leak_n
self.ta = TimeAdjuster()
Expand Down