diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index 9288a93..4f67559 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -112,6 +112,25 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True): self.index = self.data['u'] self.lam = self.data['lam'] + + def split(self, split): + tot_len = len(self.neural) + train_idx = np.arange(tot_len)[:int(tot_len*0.8)] + valid_idx = np.arange(tot_len)[int(tot_len*0.8):] + + if split == 'train': + self.neural = self.neural[train_idx] + self.index = self.index[train_idx] + self.idx = train_idx + elif split == 'valid': + self.neural = self.neural[valid_idx] + self.index = self.index[valid_idx] + self.idx = valid_idx + elif split == 'all': + pass + else: + raise ValueError(f"{split} not supported") + @property def input_dimension(self): return self.neural.size(1)