
Commit 12fa1cb

Update API
1 parent 3352ff9 commit 12fa1cb

8 files changed: +215 additions, -55 deletions

src/algorithms/epsilon_greedy_q_estimator.py

Lines changed: 30 additions & 5 deletions
@@ -8,6 +8,7 @@
 
 from src.utils.mixins import WithEstimatorMixin
 from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig
+from src.exceptions.exceptions import InvalidParamValue
 
 StateActionVec = TypeVar('StateActionVec')
 State = TypeVar('State')
@@ -42,6 +43,18 @@ def __init__(self, config: EpsilonGreedyQEstimatorConfig):
         self.gamma: float = config.gamma
         self.env: Env = config.env
         self.weights: np.array = None
+        self.initialize()
+
+    def initialize(self) -> None:
+        """Initialize the underlying weights
+
+        Returns
+        -------
+
+        None
+
+        """
+        self.weights: np.array = np.zeros((self.env.n_states * self.env.n_actions))
 
     def q_hat_value(self, state_action_vec: StateActionVec) -> float:
         """Returns the
@@ -60,6 +73,10 @@ def q_hat_value(self, state_action_vec: StateActionVec) -> float:
 
 
         """
+
+        if self.weights is None:
+            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
+
         return self.weights.dot(state_action_vec)
 
     def update_weights(self, total_reward: float, state_action: Action,
@@ -81,6 +98,10 @@ def update_weights(self, total_reward: float, state_action: Action,
         None
 
         """
+
+        if self.weights is None:
+            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
+
         v1 = self.q_hat_value(state_action_vec=state_action)
         v2 = self.q_hat_value(state_action_vec=state_action_)
         self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
@@ -99,14 +120,18 @@ def on_state(self, state: State) -> Action:
         An environment specific Action type
         """
 
-        # compute the state values related to
-        # the given state
+        # get the approximation of the q-values
+        # given the state
+
         q_values = []
 
-        for action in range(self.env.n_actions):
-            state_action_vector = self.env.get_state_action_tile(action=action, state=state)
-            q_values.append(state_action_vector)
+        for a in range(self.env.n_actions):
+            tiled_vector = self.env.featurize_state_action(action=a, state=state)
+            q_values.append(self.q_hat_value(tiled_vector))
 
         # choose an action at the current state
         action = self.eps_policy(q_values, state)
+
+        # this is an integer get the ActionBase instead
+        action = self.env.get_action(action)
         return action
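The block below is a usage sketch of the updated estimator API, not part of the commit. It assumes the class in this module is named EpsilonGreedyQEstimator and that the config object exposes the gamma and env fields read in __init__; make_env() and the numeric values are placeholders for illustration.

# Hypothetical usage sketch; EpsilonGreedyQEstimator and make_env() are assumptions.
from src.algorithms.epsilon_greedy_q_estimator import (EpsilonGreedyQEstimator,
                                                       EpsilonGreedyQEstimatorConfig)

env = make_env()                    # placeholder: any env exposing n_states, n_actions,
                                    # featurize_state_action() and get_action()
config = EpsilonGreedyQEstimatorConfig()
config.gamma = 0.99                 # fields read in __init__ per the diff
config.env = env

estimator = EpsilonGreedyQEstimator(config)   # initialize() now runs inside __init__,
                                              # so the weights are never None afterwards
state = env.reset().observation
action = estimator.on_state(state)            # returns an ActionBase via env.get_action(),
                                              # not the raw integer index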

src/algorithms/semi_gradient_sarsa.py

Lines changed: 61 additions & 6 deletions
@@ -11,8 +11,10 @@
 
 from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
 from src.utils.episode_info import EpisodeInfo
+from src.spaces.time_step import TimeStep
 from src.exceptions.exceptions import InvalidParamValue
 
+
 Policy = TypeVar('Policy')
 Env = TypeVar('Env')
 State = TypeVar('State')
@@ -38,11 +40,16 @@ class SemiGradSARSA(object):
     def __init__(self, config: SemiGradSARSAConfig) -> None:
         self.config: SemiGradSARSAConfig = config
 
+    @property
+    def name(self) -> str:
+        return "Semi-Grad SARSA"
+
     def actions_before_training(self, env: Env, **options) -> None:
         """Specify any actions necessary before training begins
 
         Parameters
         ----------
+
         env: The environment to train on
         options: Any key-value options passed by the client
@@ -60,27 +67,74 @@ def actions_before_training(self, env: Env, **options) -> None:
                 self.q_table[state, action] = 0.0
         """
 
-    def on_episode(self, env: Env, **options) -> EpisodeInfo:
+    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions to perform before the episode begins
+
+        Parameters
+        ----------
+
+        env: The instance of the training environment
+        episode_idx: The training episode index
+        options: Any keyword options passed by the client code
+
+        Returns
+        -------
+
+        None
+
+        """
+
+    def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions after the training episode ends
+
+        Parameters
+        ----------
+
+        env: The training environment
+        episode_idx: The training episode index
+        options: Any options passed by the client code
+
+        Returns
+        -------
+
+        None
+        """
+
+    def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
 
         episode_reward = 0.0
         episode_n_itrs = 0
 
         # reset the environment
-        time_step = env.reset()
+        time_step = env.reset(**{"tiled_state": False})
 
         # select a state
         state: State = time_step.observation
 
         #choose an action using the policy
-        action: Action = self.config.policy(state)
+        action: Action = self.config.policy.on_state(state)
 
         for itr in range(self.config.n_itrs_per_episode):
 
             # take action and observe reward and next_state
-            time_step = env.step(action)
-            reward: float = 0.0
+            time_step: TimeStep = env.step(action, **{"tiled_state": False})
+
+            reward: float = time_step.reward
             episode_reward += reward
-            next_state: State = None
+            next_state: State = time_step.observation
 
             # if next_state is terminal i.e. the done flag
             # is set. then update the weights
@@ -109,6 +163,7 @@ def _weights_update_episode_done(self, state: State, reward: float,
 
         Parameters
         ----------
+
         state: The current state
         reward: The reward to use
         action: The action we took at state
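A minimal driver sketch for the new SemiGradSARSA hooks, not part of the commit. It assumes SemiGradSARSAConfig exposes the policy and n_itrs_per_episode fields read in on_episode; make_env() and my_policy are placeholders.

# Hypothetical driver; make_env() and my_policy are placeholders.
from src.algorithms.semi_gradient_sarsa import SemiGradSARSA, SemiGradSARSAConfig

config = SemiGradSARSAConfig()
config.n_itrs_per_episode = 100   # read inside on_episode
config.policy = my_policy         # must expose on_state(state), per the diff

agent = SemiGradSARSA(config)
print(agent.name)                 # "Semi-Grad SARSA"

env = make_env()                  # placeholder env supporting reset/step with the tiled_state option
for episode_idx in range(10):
    agent.actions_before_episode_begins(env, episode_idx)
    episode_info = agent.on_episode(env, episode_idx)   # the episode index is now required
    agent.actions_after_episode_ends(env, episode_idx)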

src/algorithms/trainer.py

Lines changed: 25 additions & 8 deletions
@@ -60,24 +60,41 @@ def actions_before_training(self) -> None:
         self.iterations_per_episode = []
         self.agent.actions_before_training(self.env)
 
-    def actions_before_episode_begins(self, **options) -> None:
+    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
         """Perform any actions necessary before the training begins
 
         Parameters
         ----------
+        env: The environment to train on
+        episode_idx: The training episode index
         options: Any options passed by the client code
 
         Returns
         -------
 
         None
+
         """
-        self.agent.actions_before_episode_begins(**options)
+        self.agent.actions_before_episode_begins(env, episode_idx, **options)
+
+    def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions after the training episode ends
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The training episode index
+        options: Any options passed by the client code
+
+        Returns
+        -------
 
-    def actions_after_episode_ends(self, **options):
-        self.agent.actions_after_episode_ends(**options)
+        None
+        """
+        self.agent.actions_after_episode_ends(env, episode_idx, **options)
 
-        if options["episode_idx"] % self.configuration['output_msg_frequency'] == 0:
+        if episode_idx % self.configuration['output_msg_frequency'] == 0:
             if self.env.config.distorted_set_path is not None:
                 self.env.save_current_dataset(options["episode_idx"])

@@ -93,10 +110,10 @@ def train(self):
             # reset the environment
             #ignore = self.env.reset()
 
-            self.actions_before_episode_begins(**{"env": self.env})
+            self.actions_before_episode_begins(self.env, episode,)
             # train for a number of iterations
             #episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
-            episode_info: EpisodeInfo = self.agent.on_episode(self.env)
+            episode_info: EpisodeInfo = self.agent.on_episode(self.env, episode)
 
             print("{0} Episode score={1}, episode total avg distortion {2}".format(INFO, episode_info.episode_score,
                                                                                    episode_info.total_distortion / episode_info.info["n_iterations"]))
@@ -107,6 +124,6 @@ def train(self):
             self.iterations_per_episode.append(episode_info.info["n_iterations"])
             self.total_rewards[episode] = episode_info.episode_score
             self.total_distortions.append(episode_info.total_distortion)
-            self.actions_after_episode_ends(**{"episode_idx": episode})
+            self.actions_after_episode_ends(self.env, episode, **{})
 
         print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/examples/qlearning_three_columns.py

Lines changed: 4 additions & 2 deletions
@@ -88,6 +88,8 @@ def get_ethinicity_hierarchy():
 OUTPUT_MSG_FREQUENCY = 100
 N_ROUNDS_BELOW_MIN_DISTORTION = 10
 SAVE_DISTORTED_SETS_DIR = "/home/alex/qi3/drl_anonymity/src/examples/q_learn_distorted_sets/distorted_set"
+REWARD_FACTOR = 0.95
+PUNISH_FACTOR = 2.0
 
 # specify the columns to drop
 drop_columns = MockSubjectsLoader.FEATURES_DROP_NAMES + ["preventative_treatment", "gender",
@@ -144,8 +146,8 @@ def get_ethinicity_hierarchy():
                                            numeric_column_distortion_metric_type=NumericDistanceType.L2_AVG,
                                            string_column_distortion_metric_type=StringDistanceType.COSINE_NORMALIZE,
                                            dataset_distortion_type=DistortionCalculationType.SUM)
-env_config.reward_factor = 0.95
-env_config.punish_factor = 2.0
+env_config.reward_factor = REWARD_FACTOR #0.95
+env_config.punish_factor = PUNISH_FACTOR #2.0
 
 # create the environment
 env = DiscreteStateEnvironment(env_config=env_config)

src/policies/epsilon_greedy_policy.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def __init__(self, eps: float, n_actions: int,
         self.user_defined_decrease_method: UserDefinedDecreaseMethod = user_defined_decrease_method
 
     def __str__(self) -> str:
-        return self.__name__
+        return "EpsilonGreedyPolicy"
 
     def __call__(self, q_table: QTable, state: Any) -> int:
         """

src/spaces/discrete_state_environment.py

Lines changed: 15 additions & 1 deletion
@@ -66,7 +66,7 @@ def from_options(cls, *, data_set: DataSet, action_space: ActionSpace,
         return cls(env_config=config)
 
     @classmethod
-    def from_dataset(cls, data_set: DataSet, *, action_space: ActionSpace=None,
+    def from_dataset(cls, data_set: DataSet, *, action_space: ActionSpace = None,
                      reward_manager: RewardManager = None, distortion_calculator: DistortionCalculator = None):
 
         config = DiscreteEnvConfig(data_set=data_set, action_space=action_space, reward_manager=reward_manager,
@@ -115,6 +115,19 @@ def column_distortions(self) -> dict:
         return self.column_distances
 
     def get_action(self, aidx: int) -> ActionBase:
+        """Returns the action if the global aidx index
+
+        Parameters
+        ----------
+
+        aidx: The index of the action to return
+
+        Returns
+        -------
+
+        An instance of ActionBase
+
+        """
         return self.config.action_space[aidx]
 
     def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
@@ -257,6 +270,7 @@ def step(self, action: ActionBase) -> TimeStep:
         """
         # apply the action and update distoration
         # and column count
+
         self.apply_action(action=action)
 
         # calculate the distortion of the dataset
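A short sketch of the new get_action helper, not part of the commit: it maps the integer index a policy returns back to the ActionBase stored in the environment's action space. The data_set, action_space, reward_manager and distortion_calculator objects below are placeholders.

# Hypothetical wiring; all constructor arguments are placeholders.
env = DiscreteStateEnvironment.from_dataset(data_set,
                                            action_space=action_space,
                                            reward_manager=reward_manager,
                                            distortion_calculator=distortion_calculator)

aidx = 0                        # e.g. the integer chosen by an epsilon-greedy policy
action = env.get_action(aidx)   # equivalent to env.config.action_space[aidx]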
