
Commit 7dbb2dc

#27 API updates
1 parent 8cd2fb5 commit 7dbb2dc

7 files changed: +113 -55 lines changed

src/algorithms/q_learning.py

Lines changed: 26 additions & 12 deletions
@@ -6,10 +6,11 @@
 from typing import TypeVar

 from src.exceptions.exceptions import InvalidParamValue
-from src.utils.mixins import WithMaxActionMixin
+from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase

 Env = TypeVar('Env')
 Policy = TypeVar('Policy')
+Criterion = TypeVar('Criterion')


 class QLearnConfig(object):
@@ -39,8 +40,8 @@ def name(self) -> str:

     def actions_before_training(self, env: Env, **options):

-        if self.config.policy is None:
-            raise InvalidParamValue(param_name="policy", param_value="None")
+        if not isinstance(self.config.policy, WithQTableMixinBase):
+            raise InvalidParamValue(param_name="policy", param_value=str(self.config.policy))

         for state in range(1, env.n_states):
             for action in range(env.n_actions):
@@ -56,10 +57,11 @@ def actions_after_episode_ends(self, **options):

         self.config.policy.actions_after_episode(options['episode_idx'])

-    def play(self, env: Env) -> None:
+    def play(self, env: Env, stop_criterion: Criterion) -> None:
         """
         Play the game on the environment. This should produce
         a distorted dataset
+        :param stop_criterion:
         :param env:
         :return:
         """
@@ -69,7 +71,23 @@ def play(self, env: Env) -> None:
         # the max payout.
         # TODO: This will not work as the distortion is calculated
         # by summing over the columns.
-        raise NotImplementedError("Function not implemented")
+
+        # set the q_table for the policy
+        self.config.policy.q_table = self.q_table
+        total_dist = env.total_average_current_distortion()
+        while stop_criterion.continue_itr(total_dist):
+
+            if stop_criterion.iteration_counter == 12:
+                print("Break...")
+
+            # use the policy to select an action
+            state_idx = env.get_aggregated_state(total_dist)
+            action_idx = self.config.policy.on_state(state_idx)
+            action = env.get_action(action_idx)
+            print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
+                                                                                   action.column_name + "-" + action.action_type.name))
+            env.step(action=action)
+            total_dist = env.total_average_current_distortion()

     def train(self, env: Env, **options) -> tuple:

@@ -84,15 +102,10 @@ def train(self, env: Env, **options) -> tuple:
         for itr in range(self.config.n_itrs_per_episode):

             # epsilon-greedy action selection
-            action_idx = self.config.policy(q_func=self.q_table, state=state)
+            action_idx = self.config.policy(q_table=self.q_table, state=state)

             action = env.get_action(action_idx)

-            #if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
-            #    print("Attempt to generalize salary")
-            #else:
-            #    print(action.action_type.name, " on ", action.column_name)
-
             # take action A, observe R, S'
             next_time_step = env.step(action)
             next_state = next_time_step.observation
@@ -111,7 +124,8 @@ def train(self, env: Env, **options) -> tuple:

         return episode_score, total_distortion, counter

-    def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
+    def _update_Q_table(self, state: int, action: int, n_actions: int,
+                        reward: float, next_state: int = None) -> None:
         """
         Update the Q-value for the state
         """

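A note on the new play() signature: the loop above only relies on the stop criterion exposing a continue_itr(total_distortion) method and an iteration_counter attribute. Below is a minimal sketch of such a criterion, assuming only that interface; the class name IterationDistortionCriterion and its constructor arguments are hypothetical, not part of this commit.

class IterationDistortionCriterion(object):
    """Stop when the distortion enters the target band or when the
    iteration budget is exhausted. Illustrative only."""

    def __init__(self, max_iterations: int, min_distortion: float, max_distortion: float):
        self.max_iterations = max_iterations
        self.min_distortion = min_distortion
        self.max_distortion = max_distortion
        self.iteration_counter = 0

    def continue_itr(self, total_distortion: float) -> bool:
        # stop once the distortion falls inside the acceptable band
        if self.min_distortion <= total_distortion <= self.max_distortion:
            return False
        # otherwise keep iterating until the budget runs out
        self.iteration_counter += 1
        return self.iteration_counter <= self.max_iterations

# usage sketch, assuming a trained agent and a configured environment:
# agent.play(env, stop_criterion=IterationDistortionCriterion(100, 0.3, 0.7))
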
src/algorithms/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -32,6 +32,17 @@ def avg_rewards(self) -> np.array:
             avg[i] = self.total_rewards[i] / self.iterations_per_episode[i]
         return avg

+    def avg_distortion(self) -> np.array:
+        """
+        Returns the average distortion per episode
+        :return:
+        """
+        avg = np.zeros(self.configuration['n_episodes'])
+
+        for i in range(len(self.total_distortions)):
+            avg[i] = self.total_distortions[i] / self.iterations_per_episode[i]
+        return avg
+
     def actions_before_training(self):
         """
         Any actions to perform before training begins
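
The new avg_distortion() mirrors avg_rewards(): it divides each episode's accumulated distortion by the number of iterations that episode ran. A standalone check of the same arithmetic, with made-up numbers:

import numpy as np

total_distortions = np.array([12.0, 9.0, 4.5])    # hypothetical per-episode totals
iterations_per_episode = np.array([10, 6, 3])     # hypothetical episode lengths

avg = total_distortions / iterations_per_episode
print(avg)  # [1.2 1.5 1.5]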

src/policies/epsilon_greedy_policy.py

Lines changed: 15 additions & 4 deletions
@@ -6,7 +6,6 @@
 from enum import Enum
 from typing import Any, TypeVar

-
 from src.utils.mixins import WithMaxActionMixin

 UserDefinedDecreaseMethod = TypeVar('UserDefinedDecreaseMethod')
@@ -42,26 +41,38 @@ def __init__(self, env: Env, eps: float,
         self._epsilon_decay_factor = epsilon_decay_factor
         self.user_defined_decrease_method: UserDefinedDecreaseMethod = user_defined_decrease_method

-    def __call__(self, q_func: QTable, state: Any) -> int:
+    def __str__(self) -> str:
+        return self.__name__
+
+    def __call__(self, q_table: QTable, state: Any) -> int:
         """
         Execute the policy
         :param q_func:
         :param state:
         :return:
         """

+        # update the stored q_table
+        self.q_table = q_table
+
         # select greedy action with probability epsilon
         if random.random() > self._eps:
-            self.q_table = q_func
             return self.max_action(state=state, n_actions=self._n_actions)
-
         else:

             # otherwise, select an action randomly
             # what happens if we select an action that
             # has exhausted its transforms?
             return random.choice(np.arange(self._n_actions))

+    def on_state(self, state: Any) -> int:
+        """
+        Returns the optimal action on the current state
+        :param state:
+        :return:
+        """
+        return self.max_action(state=state, n_actions=self._n_actions)
+
     def actions_after_episode(self, episode_idx: int, **options) -> None:
         """
         Apply actions on the policy after the end of the episode
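
With this change the policy keeps whatever table it was last called with: training code passes q_table= on every __call__, while play-time code asks for the greedy action through on_state(). A standalone sketch of the same epsilon-greedy selection rule over a small table follows; the dict-style q_table, the eps value and the numbers are made up for illustration and are not the project's.

import random
import numpy as np

# Q-values for one state with three actions, keyed as (state, action)
# in the same style the mixins index the table.
q_table = {(0, 0): 0.1, (0, 1): 0.7, (0, 2): 0.3}
n_actions = 3
eps = 0.1

def select_action(state: int) -> int:
    # exploit with probability 1 - eps, explore otherwise
    if random.random() > eps:
        values = np.array([q_table[state, a] for a in range(n_actions)])
        return int(np.argmax(values))
    return int(random.choice(np.arange(n_actions)))

print(select_action(0))  # usually 1, the greedy action for state 0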

src/spaces/discrete_state_environment.py

Lines changed: 34 additions & 31 deletions
@@ -129,6 +129,11 @@ def get_action(self, aidx: int) -> ActionBase:
         return self.config.action_space[aidx]

     def save_current_dataset(self, episode_index: int) -> None:
+        """
+        Save the current distorted dataset for the given episode index
+        :param episode_index:
+        :return:
+        """
         self.distorted_data_set.save_to_csv(
             filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))

@@ -192,7 +197,8 @@ def apply_action(self, action: ActionBase):
             return

         # apply the transform of the data set
-        self.distorted_data_set.apply_column_transform(column_name=action.column_name, transform=action)
+        self.distorted_data_set.apply_column_transform(column_name=action.column_name,
+                                                       transform=action)

         # what are the previous and current values for the column
         current_column = self.distorted_data_set.get_column(col_name=action.column_name)
@@ -205,14 +211,8 @@ def apply_action(self, action: ActionBase):
             start_column = "".join(start_column.values)
             datatype = 'str'

-        # join the column to calculate the distance
-        # distance = self.string_distance_calculator.calculate(txt1="".join(current_column.values),
-        #                                                      txt2="".join(start_column.values))
-        # else:
-        #     distance = NumericDistanceCalculator(dist_type=self.config.numeric_column_distortion_metric_type)\
-        #         .calculate(state1=current_column, state2=start_column)
-
-        distance = self.config.distortion_calculator.calculate(current_column, start_column, datatype)
+        distance = self.config.distortion_calculator.calculate(current_column,
+                                                               start_column, datatype)

         self.column_distances[action.column_name] = distance

@@ -312,35 +312,38 @@ def step(self, action: ActionBase) -> TimeStep:

         # TODO: these modifications will cause the agent to always
         # move close to transition points
-        if next_state < min_dist_bin <= self.current_time_step.observation:
-            # the agent chose to step into the chaos again
-            # we punish him with double the reward
-            reward = 2.0 * self.config.reward_manager.out_of_min_bound_reward
-        elif next_state > max_dist_bin >= self.current_time_step.observation:
-            # the agent is going to chaos from above
-            # punish him
-            reward = 2.0 * self.config.reward_manager.out_of_max_bound_reward
-
-        elif next_state >= min_dist_bin > self.current_time_step.observation:
-            # the agent goes towards the transition of min point so give a higher reward
-            # for this
-            reward = 0.95 * self.config.reward_manager.in_bounds_reward
-
-        elif next_state <= max_dist_bin < self.current_time_step.observation:
-            # the agent goes towards the transition of max point so give a higher reward
-            # for this
-            reward = 0.95 * self.config.reward_manager.in_bounds_reward
-
-        if next_state >= self.n_states:
+        if next_state is not None and self.current_time_step.observation is not None:
+            if next_state < min_dist_bin <= self.current_time_step.observation:
+                # the agent chose to step into the chaos again
+                # we punish him with double the reward
+                reward = 2.0 * self.config.reward_manager.out_of_min_bound_reward
+            elif next_state > max_dist_bin >= self.current_time_step.observation:
+                # the agent is going to chaos from above
+                # punish him
+                reward = 2.0 * self.config.reward_manager.out_of_max_bound_reward
+
+            elif next_state >= min_dist_bin > self.current_time_step.observation:
+                # the agent goes towards the transition of min point so give a higher reward
+                # for this
+                reward = 0.95 * self.config.reward_manager.in_bounds_reward
+
+            elif next_state <= max_dist_bin < self.current_time_step.observation:
+                # the agent goes towards the transition of max point so give a higher reward
+                # for this
+                reward = 0.95 * self.config.reward_manager.in_bounds_reward
+
+        if next_state is None or next_state >= self.n_states:
             done = True

         if done:
             step_type = StepType.LAST
             next_state = None

-        self.current_time_step = TimeStep(step_type=step_type, reward=reward,
+        self.current_time_step = TimeStep(step_type=step_type,
+                                          reward=reward,
                                           observation=next_state,
-                                          discount=self.config.gamma, info={"total_distortion": current_distortion})
+                                          discount=self.config.gamma,
+                                          info={"total_distortion": current_distortion})

         return self.current_time_step

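
The guard added to step() only applies the bin-based reward shaping when both the next state and the current observation are known; a None observation now ends the episode instead of raising a TypeError in the comparisons. Below is a compact sketch of that decision as a standalone function; the reward constants, bin values and the function name are made up for illustration, the real values come from the environment's reward manager.

from typing import Optional

OUT_OF_BOUND_REWARD = -1.0   # hypothetical punishment constant
IN_BOUNDS_REWARD = 1.0       # hypothetical encouragement constant

def shaped_reward(next_state: Optional[int], current_state: Optional[int],
                  min_dist_bin: int, max_dist_bin: int, base_reward: float) -> float:
    # with a missing state we cannot compare against the bins, so keep the base reward
    if next_state is None or current_state is None:
        return base_reward
    if next_state < min_dist_bin <= current_state:
        # stepped back below the minimum-distortion bin: punish
        return 2.0 * OUT_OF_BOUND_REWARD
    if next_state > max_dist_bin >= current_state:
        # stepped above the maximum-distortion bin: punish
        return 2.0 * OUT_OF_BOUND_REWARD
    if next_state >= min_dist_bin > current_state or next_state <= max_dist_bin < current_state:
        # moving towards the allowed band: mildly encourage
        return 0.95 * IN_BOUNDS_REWARD
    return base_reward

print(shaped_reward(None, 3, 2, 5, 0.0))  # 0.0: terminal step, no shaping
print(shaped_reward(3, 1, 2, 5, 0.0))     # 0.95: moving into the band from below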

src/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-INFO = "INFO: "
+from src.utils.version import VERSION
+INFO = "INFO: "

src/utils/mixins.py

Lines changed: 25 additions & 6 deletions
@@ -3,8 +3,11 @@
 """

 import numpy as np
+import abc
 from typing import TypeVar, Any

+from src.exceptions.exceptions import InvalidParamValue
+
 QTable = TypeVar('QTable')
 Hierarchy = TypeVar('Hierarchy')

@@ -48,26 +51,42 @@ def finished(self) -> bool:
         return exhausted


-class WithQTableMixin(object):
+class WithQTableMixinBase(metaclass=abc.ABCMeta):
     """
-    Helper class to associate a q_table with an algorithm
-    if this is needed.
+    Base class to impose the concept of Q-table
     """
+
     def __init__(self):
         # the table representing the q function
         # client code should choose the type of
         # the table
         self.q_table: QTable = None


-class WithMaxActionMixin(object):
+class WithQTableMixin(WithQTableMixinBase):
+    """
+    Helper class to associate a q_table with an algorithm
+    if this is needed.
+    """
+    def __init__(self):
+        super(WithQTableMixin, self).__init__()
+
+    def state_action_values(self, state: Any, n_actions: int):
+
+        if self.q_table is None:
+            raise InvalidParamValue(param_name="q_table", param_value="None")
+
+        values = [self.q_table[state, a] for a in range(n_actions)]
+        return values
+
+
+class WithMaxActionMixin(WithQTableMixin):
     """
     The class WithMaxActionMixin.
     """

     def __init__(self):
         super(WithMaxActionMixin, self).__init__()
-        self.q_table: QTable = None

     def max_action(self, state: Any, n_actions: int) -> int:
         """
@@ -77,7 +96,7 @@ def max_action(self, state: Any, n_actions: int) -> int:
         :param n_actions: Total number of actions allowed
         :return: The action that corresponds to the maximum value
         """
-        values = [self.q_table[state, a] for a in range(n_actions)]
+        values = self.state_action_values(state, n_actions)  # [self.q_table[state, a] for a in range(n_actions)]
         values = np.array(values)
         action = np.argmax(values)
         return int(action)
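
The mixin split means any policy that carries a Q-table now derives from WithQTableMixinBase, which is what the new isinstance check in QLearning.actions_before_training relies on. Below is a minimal sketch of the same pattern with a dict-backed table; the class and variable names here are illustrative, not the project's.

import abc
from typing import Any

class QTableHolderBase(metaclass=abc.ABCMeta):
    # marker base class: anything carrying a Q-table derives from this
    def __init__(self):
        self.q_table = None

class GreedyPolicy(QTableHolderBase):
    def __init__(self, n_actions: int):
        super().__init__()
        self.n_actions = n_actions

    def max_action(self, state: Any) -> int:
        # gather the action values for the state and pick the best one
        values = [self.q_table[state, a] for a in range(self.n_actions)]
        return max(range(self.n_actions), key=lambda a: values[a])

policy = GreedyPolicy(n_actions=2)
policy.q_table = {(0, 0): 0.2, (0, 1): 0.8}

assert isinstance(policy, QTableHolderBase)   # the kind of check the algorithm now performs
print(policy.max_action(state=0))             # 1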

src/utils/serial_hierarchy.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@

 from typing import List, Any
 from src.utils.hierarchy_base import HierarchyBase
-from src.utils.updateable_map import UpdateableMap


 class SerialtHierarchyIterator(object):
