Skip to content

Commit 3425e22

Browse files
committed
Fix tiling
1 parent 11ef082 commit 3425e22

File tree

8 files changed

+456
-86
lines changed

8 files changed

+456
-86
lines changed

src/datasets/datasets_loaders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class MockSubjectsData(object):
1616

1717
# Path to the dataset file
18-
FILENAME: Path = Path("../../data/mocksubjects.csv")
18+
FILENAME: Path = Path("/home/alex/qi3/drl_anonymity/data/mocksubjects.csv") #("../../data/mocksubjects.csv")
1919

2020
# the assumed column types. We use this map to cast
2121
# the types of the columns
@@ -57,7 +57,7 @@ def from_options(cls, *, filename: Path,
5757
NORMALIZED_COLUMNS=column_normalization)
5858
return cls(data=data)
5959

60-
def __init__(self, data: MockSubjectsData, do_read: bool=True):
60+
def __init__(self, data: MockSubjectsData, do_read: bool = True):
6161
super(MockSubjectsLoader, self).__init__(columns=data.COLUMNS_TYPES)
6262

6363
if do_read:

src/spaces/action_space.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@
1010

1111

1212
class ActionSpace(Discrete):
13+
"""ActionSpace class models a discrete action space of size n
1314
"""
14-
ActionSpace class models a discrete action space of size n
15-
"""
15+
16+
@classmethod
17+
def from_actions(cls, *actions: ActionBase):
18+
19+
space = cls(n=len(actions))
20+
space.add_many(*actions)
21+
return space
1622

1723
def __init__(self, n: int) -> None:
1824

src/spaces/discrete_state_environment.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ def from_options(cls, *, data_set: DataSet, action_space: ActionSpace,
6262
punish_factor=punish_factor, max_distortion=max_distortion, gamma=gamma,
6363
n_states=n_states, min_distortion=min_distortion,
6464
average_distortion_constraint=average_distortion_constraint)
65+
66+
return cls(env_config=config)
67+
68+
@classmethod
69+
def from_dataset(cls, data_set: DataSet, *, action_space: ActionSpace=None,
70+
reward_manager: RewardManager = None, distortion_calculator: DistortionCalculator = None):
71+
72+
config = DiscreteEnvConfig(data_set=data_set, action_space=action_space, reward_manager=reward_manager,
73+
distortion_calculator=distortion_calculator)
6574
return cls(env_config=config)
6675

6776
def __init__(self, env_config: DiscreteEnvConfig) -> None:

src/spaces/state.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
"""
66

7+
import numpy as np
78
from typing import TypeVar, List, Any
89
from src.exceptions.exceptions import Error
910

@@ -65,12 +66,12 @@ def __init__(self):
6566
self.column_distortions = {}
6667

6768
def __contains__(self, column_name: str) -> bool:
68-
"""
69-
Returns true if column_name is in the column_distortions
69+
"""Returns true if column_name is in the column_distortions
7070
keys
7171
7272
Parameters
7373
----------
74+
7475
column_name: The column name to query
7576
7677
Returns
@@ -85,11 +86,11 @@ def __iter__(self):
8586
return StateIterator(list(self.column_distortions.keys()))
8687

8788
def __getitem__(self, name: str) -> float:
88-
"""
89-
Get the distortion corresponding to the name-th column
89+
"""Get the distortion corresponding to the name-th column
9090
9191
Parameters
9292
----------
93+
9394
name: The name of the column
9495
9596
Returns
@@ -99,4 +100,15 @@ def __getitem__(self, name: str) -> float:
99100
"""
100101
return self.column_distortions[name]
101102

103+
def to_numpy(self) -> np.array:
104+
"""Returns the self.column_distortions values as numpy array
105+
106+
Returns
107+
-------
108+
np.array
109+
110+
"""
111+
112+
vals = self.column_distortions.values()
113+
return np.array(vals)
102114

0 commit comments

Comments
 (0)