Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion qlib/backtest/high_performance_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ def __len__(self):
class PandasSingleMetric(SingleMetric):
"""Each SingleMetric is based on pd.Series."""

def __init__(self, metric: Union[dict, pd.Series] = {}):
def __init__(self, metric: Union[dict, pd.Series] | None = None):
if metric is None:
metric = {}
if isinstance(metric, dict):
self.metric = pd.Series(metric)
elif isinstance(metric, pd.Series):
Expand Down
41 changes: 30 additions & 11 deletions qlib/data/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ class DatasetCache(BaseProviderCache):
HDF_KEY = "df"

def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
):
"""Get feature dataset.

Expand All @@ -399,6 +399,8 @@ def dataset(
read-write conflicts will not be triggered
but client readers are not considered.
"""
if inst_processors is None:
inst_processors = []
if disk_cache == 0:
# skip cache
return self.provider.dataset(
Expand All @@ -423,7 +425,7 @@ def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
raise NotImplementedError("Implement this function to match your own cache mechanism")

def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
):
"""Get feature dataset using cache.

Expand All @@ -432,7 +434,7 @@ def _dataset(
raise NotImplementedError("Implement this method if you want to use dataset feature cache")

def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
):
"""Get a uri of feature dataset using cache.
specially:
Expand Down Expand Up @@ -653,7 +655,9 @@ def __init__(self, provider, **kwargs):
self.remote = kwargs.get("remote", False)

@staticmethod
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
if inst_processors is None:
inst_processors = []
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)

def get_cache_dir(self, freq: str = None) -> Path:
Expand Down Expand Up @@ -694,8 +698,10 @@ def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time
return df

def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
):
if inst_processors is None:
inst_processors = []
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(
Expand Down Expand Up @@ -748,8 +754,10 @@ def _dataset(
return features

def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
):
if inst_processors is None:
inst_processors = []
if disk_cache == 0:
# In this case, server only checks the expression cache.
# The client will load the cache data by itself.
Expand Down Expand Up @@ -854,7 +862,7 @@ def build_index_from_data(data, start_index=0):
index_data += start_index
return index_data

def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None):
"""gen_dataset_cache

.. note:: This function does not consider the cache read write lock. Please
Expand All @@ -872,6 +880,9 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f
1999-11-10 00:00:00 0 1
1999-11-11 00:00:00 1 2
1999-11-12 00:00:00 2 3
"""
if inst_processors is None:
inst_processors = []
...

.. note:: The start is closed. The end is open!!!!!
Expand Down Expand Up @@ -1076,15 +1087,19 @@ def __init__(self, provider):
f"modify the cache directory via the local_cache_path in the config"
)

def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
if inst_processors is None:
inst_processors = []
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
return hash_args(
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
)

def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
):
if inst_processors is None:
inst_processors = []
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
Expand Down Expand Up @@ -1118,12 +1133,16 @@ def _dataset(
class DatasetURICache(DatasetCache):
"""Prepared cache mechanism for server."""

def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
if inst_processors is None:
inst_processors = []
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)

def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
):
if inst_processors is None:
inst_processors = []
if "local" in C.dataset_provider.lower():
# use LocalDatasetProvider
return self.provider.dataset(
Expand Down
40 changes: 30 additions & 10 deletions qlib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ class DatasetProvider(abc.ABC):
"""

@abc.abstractmethod
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None):
"""Get dataset data.

Parameters
Expand All @@ -473,6 +473,8 @@ def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day
pd.DataFrame
a pandas dataframe with <instrument, datetime> index.
"""
if inst_processors is None:
inst_processors = []
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")

def _uri(
Expand All @@ -483,7 +485,7 @@ def _uri(
end_time=None,
freq="day",
disk_cache=1,
inst_processors=[],
inst_processors=None,
**kwargs,
):
"""Get task uri, used when generating rabbitmq task in qlib_server
Expand All @@ -504,6 +506,8 @@ def _uri(
whether to skip(0)/use(1)/replace(2) disk_cache.

"""
if inst_processors is None:
inst_processors = []
# TODO: qlib-server support inst_processors
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)

Expand Down Expand Up @@ -545,12 +549,14 @@ def parse_fields(fields):
return [ExpressionD.get_expression_instance(f) for f in fields]

@staticmethod
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None):
"""
Load and process the data, return the data set.
- default using multi-kernel method.

"""
if inst_processors is None:
inst_processors = []
normalize_column_names = normalize_cache_fields(column_names)
# One process for one task, so that the memory will be freed quicker.
workers = max(min(C.get_kernels(freq), len(instruments_d)), 1)
Expand Down Expand Up @@ -597,14 +603,16 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i
return data

@staticmethod
def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]):
def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None):
"""
Calculate the expressions for **one** instrument, return a df result.
If the expression has been calculated before, load from cache.

return value: A data frame with index 'datetime' and other data columns.

"""
if inst_processors is None:
inst_processors = []
# FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods
# NOTE: This place is compatible with windows, windows multi-process is spawn
C.register_from_C(g_config)
Expand Down Expand Up @@ -640,7 +648,9 @@ class LocalCalendarProvider(CalendarProvider, ProviderBackendMixin):
Provide calendar data from local data source.
"""

def __init__(self, remote=False, backend={}):
def __init__(self, remote=False, backend=None):
if backend is None:
backend = {}
super().__init__()
self.remote = remote
self.backend = backend
Expand Down Expand Up @@ -681,7 +691,9 @@ class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
Provide instrument data from local data source.
"""

def __init__(self, backend={}) -> None:
def __init__(self, backend=None) -> None:
if backend is None:
backend = {}
super().__init__()
self.backend = backend

Expand Down Expand Up @@ -729,7 +741,9 @@ class LocalFeatureProvider(FeatureProvider, ProviderBackendMixin):
Provide feature data from local data source.
"""

def __init__(self, remote=False, backend={}):
def __init__(self, remote=False, backend=None):
if backend is None:
backend = {}
super().__init__()
self.remote = remote
self.backend = backend
Expand Down Expand Up @@ -906,8 +920,10 @@ def dataset(
start_time=None,
end_time=None,
freq="day",
inst_processors=[],
inst_processors=None,
):
if inst_processors is None:
inst_processors = []
instruments_d = self.get_instruments_d(instruments, freq)
column_names = self.get_column_names(fields)
if self.align_time:
Expand Down Expand Up @@ -1046,8 +1062,10 @@ def dataset(
freq="day",
disk_cache=0,
return_uri=False,
inst_processors=[],
inst_processors=None,
):
if inst_processors is None:
inst_processors = []
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
"Getting features from a dict of instruments is not recommended because the features will not be "
Expand Down Expand Up @@ -1167,7 +1185,7 @@ def features(
end_time=None,
freq="day",
disk_cache=None,
inst_processors=[],
inst_processors=None,
):
"""
Parameters
Expand All @@ -1180,6 +1198,8 @@ def features(
and will use provider method if a type error is raised because the DatasetD instance
is a provider class.
"""
if inst_processors is None:
inst_processors = []
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
fields = list(fields) # In case of tuple.
try:
Expand Down
8 changes: 6 additions & 2 deletions qlib/data/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def is_for_infer(self) -> bool:


class DropCol(Processor):
def __init__(self, col_list=[]):
def __init__(self, col_list=None):
if col_list is None:
col_list = []
self.col_list = col_list

def __call__(self, df):
Expand All @@ -127,7 +129,9 @@ def readonly(self):


class FilterCol(Processor):
def __init__(self, fields_group="feature", col_list=[]):
def __init__(self, fields_group="feature", col_list=None):
if col_list is None:
col_list = []
self.fields_group = fields_group
self.col_list = col_list

Expand Down
12 changes: 9 additions & 3 deletions qlib/utils/index_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,19 +528,25 @@ def values(self):

class SingleData(IndexData):
def __init__(
self, data: Union[int, float, np.number, list, dict, pd.Series] = [], index: Union[List, pd.Index, Index] = []
self,
data: Union[int, float, np.number, list, dict, pd.Series] | None = None,
index: Union[List, pd.Index, Index] | None = None,
):
"""A data structure of index and numpy data.
It's used to replace pd.Series due to high-speed.

Parameters
----------
data : Union[int, float, np.number, list, dict, pd.Series]
data : Union[int, float, np.number, list, dict, pd.Series], optional
the input data
index : Union[list, pd.Index]
index : Union[list, pd.Index], optional
the index of data.
empty list indicates that auto filling the index to the length of data
"""
if data is None:
data = []
if index is None:
index = []
# for special data type
if isinstance(data, dict):
assert len(index) == 0
Expand Down