Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion qlib/contrib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

class ArcticFeatureProvider(FeatureProvider):
def __init__(
self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")]
self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=None
):
if market_transaction_time_list is None:
market_transaction_time_list = [("09:15", "11:30"), ("13:00", "15:00")]
super().__init__()
self.uri = uri
# TODO:
Expand Down
16 changes: 12 additions & 4 deletions qlib/contrib/data/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ def __init__(
start_time=None,
end_time=None,
freq="day",
infer_processors=_DEFAULT_INFER_PROCESSORS,
learn_processors=_DEFAULT_LEARN_PROCESSORS,
infer_processors=None,
learn_processors=None,
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processors=None,
**kwargs,
):
if infer_processors is None:
infer_processors = _DEFAULT_INFER_PROCESSORS
if learn_processors is None:
learn_processors = _DEFAULT_LEARN_PROCESSORS
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

Expand Down Expand Up @@ -102,15 +106,19 @@ def __init__(
start_time=None,
end_time=None,
freq="day",
infer_processors=[],
learn_processors=_DEFAULT_LEARN_PROCESSORS,
infer_processors=None,
learn_processors=None,
fit_start_time=None,
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processors=None,
**kwargs,
):
if infer_processors is None:
infer_processors = []
if learn_processors is None:
learn_processors = _DEFAULT_LEARN_PROCESSORS
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

Expand Down
24 changes: 18 additions & 6 deletions qlib/contrib/data/highfreq_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ def __init__(
instruments="csi300",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
infer_processors=None,
learn_processors=None,
fit_start_time=None,
fit_end_time=None,
drop_raw=True,
):
if infer_processors is None:
infer_processors = []
if learn_processors is None:
learn_processors = []
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

Expand Down Expand Up @@ -106,8 +110,8 @@ def __init__(
instruments="csi300",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
infer_processors=None,
learn_processors=None,
fit_start_time=None,
fit_end_time=None,
drop_raw=True,
Expand All @@ -116,6 +120,10 @@ def __init__(
columns=["$open", "$high", "$low", "$close", "$vwap"],
inst_processors=None,
):
if infer_processors is None:
infer_processors = []
if learn_processors is None:
learn_processors = []
self.day_length = day_length
self.columns = columns

Expand Down Expand Up @@ -310,13 +318,17 @@ def __init__(
instruments="csi300",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
infer_processors=None,
learn_processors=None,
fit_start_time=None,
fit_end_time=None,
inst_processors=None,
drop_raw=True,
):
if infer_processors is None:
infer_processors = []
if learn_processors is None:
learn_processors = []
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

Expand Down
8 changes: 6 additions & 2 deletions qlib/contrib/data/highfreq_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def _prepare_calender_cache(self):
Cal.calendar(freq=self.freq)
get_calendar_day(freq=self.freq)

def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
def _gen_dataframe(self, config, datasets=None):
if datasets is None:
datasets = ["train", "valid", "test"]
try:
path = config.pop("path")
except KeyError as e:
Expand Down Expand Up @@ -163,7 +165,9 @@ def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res

def _gen_data(self, config, datasets=["train", "valid", "test"]):
def _gen_data(self, config, datasets=None):
if datasets is None:
datasets = ["train", "valid", "test"]
try:
path = config.pop("path")
except KeyError as e:
Expand Down
20 changes: 10 additions & 10 deletions qlib/contrib/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,7 @@ def __init__(self, config=None, **kwargs):
super().__init__(config=_config, **kwargs)

@staticmethod
def get_feature_config(
config={
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
):
def get_feature_config(config=None):
"""create factors from config

config = {
Expand All @@ -99,6 +90,15 @@ def get_feature_config(
}
}
"""
if config is None:
config = {
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
fields = []
names = []
if "kbar" in config:
Expand Down
4 changes: 3 additions & 1 deletion qlib/contrib/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def long_short_backtest(
trade_unit=None,
limit_threshold=None,
min_cost=5,
subscribe_fields=[],
subscribe_fields=None,
extract_codes=False,
):
"""
Expand All @@ -307,6 +307,8 @@ def long_short_backtest(
"short": short_returns(excess),
"long_short": long_short_returns}
"""
if subscribe_fields is None:
subscribe_fields = []
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()

Expand Down
4 changes: 3 additions & 1 deletion qlib/contrib/model/pytorch_adarnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,16 @@ def __init__(
use_bottleneck=False,
bottleneck_width=256,
n_input=128,
n_hiddens=[64, 64],
n_hiddens=None,
n_output=6,
dropout=0.0,
len_seq=9,
model_type="AdaRNN",
trans_loss="mmd",
GPU=0,
):
if n_hiddens is None:
n_hiddens = [64, 64]
super(AdaRNN, self).__init__()
self.use_bottleneck = use_bottleneck
self.n_input = n_input
Expand Down
14 changes: 8 additions & 6 deletions qlib/contrib/model/pytorch_general_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ def __init__(
GPU=0,
seed=None,
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
pt_model_kwargs={
"d_feat": 6,
"hidden_size": 64,
"num_layers": 2,
"dropout": 0.0,
},
pt_model_kwargs=None,
):
if pt_model_kwargs is None:
pt_model_kwargs = {
"d_feat": 6,
"hidden_size": 64,
"num_layers": 2,
"dropout": 0.0,
}
# Set logger.
self.logger = get_module_logger("GeneralPTNN")
self.logger.info("GeneralPTNN pytorch version...")
Expand Down
10 changes: 6 additions & 4 deletions qlib/contrib/model/pytorch_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ def __init__(
init_model=None,
eval_train_metric=False,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
pt_model_kwargs={
"input_dim": 360,
"layers": (256,),
},
pt_model_kwargs=None,
valid_key=DataHandlerLP.DK_L,
# TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
):
if pt_model_kwargs is None:
pt_model_kwargs = {
"input_dim": 360,
"layers": (256,),
}
# Set logger.
self.logger = get_module_logger("DNNModelPytorch")
self.logger.info("DNN pytorch version...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def model_performance_graph(
N: int = 5,
reverse=False,
rank=False,
graph_names: list = ["group_return", "pred_ic", "pred_autocorr"],
graph_names: list = None,
show_notebook: bool = True,
show_nature_day: bool = False,
**kwargs,
Expand Down Expand Up @@ -328,6 +328,8 @@ def model_performance_graph(
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
"""
if graph_names is None:
graph_names = ["group_return", "pred_ic", "pred_autocorr"]
figure_list = []
for graph_name in graph_names:
fun_res = eval(f"_{graph_name}")(
Expand Down
6 changes: 3 additions & 3 deletions qlib/contrib/strategy/optimizer/enhanced_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
f_dev: Optional[Union[List[float], np.ndarray]] = None,
scale_return: bool = True,
epsilon: float = 5e-5,
solver_kwargs: Optional[Dict[str, Any]] = {},
solver_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Args:
Expand All @@ -63,8 +63,8 @@ def __init__(
epsilon (float): minimum weight
solver_kwargs (dict): kwargs for cvxpy solver
"""

assert lamb >= 0, "risk aversion parameter `lamb` should be positive"
if solver_kwargs is None:
solver_kwargs = {}
self.lamb = lamb

assert delta >= 0, "turnover limit `delta` should be positive"
Expand Down
8 changes: 6 additions & 2 deletions qlib/contrib/strategy/signal_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,15 @@ def __init__(
riskmodel_root,
market="csi500",
turn_limit=None,
name_mapping={},
optimizer_kwargs={},
name_mapping=None,
optimizer_kwargs=None,
verbose=False,
**kwargs,
):
if name_mapping is None:
name_mapping = {}
if optimizer_kwargs is None:
optimizer_kwargs = {}
super().__init__(**kwargs)

self.logger = get_module_logger("EnhancedIndexingStrategy")
Expand Down
4 changes: 3 additions & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self,
handler: Union[Dict, DataHandler],
segments: Dict[Text, Tuple],
fetch_kwargs: Dict = {},
fetch_kwargs: Dict = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -116,6 +116,8 @@ def __init__(
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
if fetch_kwargs is None:
fetch_kwargs = {}
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = copy(fetch_kwargs)
Expand Down
13 changes: 10 additions & 3 deletions qlib/data/dataset/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,9 @@ def __init__(
start_time=None,
end_time=None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors: List = [],
learn_processors: List = [],
shared_processors: List = [],
infer_processors: List = None,
learn_processors: List = None,
shared_processors: List = None,
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
Expand Down Expand Up @@ -489,6 +489,13 @@ def __init__(
Whether to drop the raw data
"""

if infer_processors is None:
infer_processors = []
if learn_processors is None:
learn_processors = []
if shared_processors is None:
shared_processors = []

# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
Expand Down
5 changes: 4 additions & 1 deletion qlib/data/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class DataLoaderDH(DataLoader):
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
"""

def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False):
"""
Parameters
----------
Expand All @@ -386,6 +386,9 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False
"""
from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415

if fetch_kwargs is None:
fetch_kwargs = {}

if is_group:
self.handlers = {
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
Expand Down