 from typing import TypeVar

 from src.exceptions.exceptions import InvalidParamValue
-from src.utils.mixins import WithMaxActionMixin
+from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase

 Env = TypeVar('Env')
 Policy = TypeVar('Policy')
+Criterion = TypeVar('Criterion')


 class QLearnConfig(object):
@@ -39,8 +40,8 @@ def name(self) -> str:

     def actions_before_training(self, env: Env, **options):

-        if self.config.policy is None:
-            raise InvalidParamValue(param_name="policy", param_value="None")
+        if not isinstance(self.config.policy, WithQTableMixinBase):
+            raise InvalidParamValue(param_name="policy", param_value=str(self.config.policy))

         for state in range(1, env.n_states):
             for action in range(env.n_actions):
@@ -56,10 +57,11 @@ def actions_after_episode_ends(self, **options):

         self.config.policy.actions_after_episode(options['episode_idx'])

-    def play(self, env: Env) -> None:
+    def play(self, env: Env, stop_criterion: Criterion) -> None:
         """
         Play the game on the environment. This should produce
         a distorted dataset.
+        :param stop_criterion: criterion deciding when the play loop should stop
         :param env:
         :return:
         """
@@ -69,7 +71,23 @@ def play(self, env: Env) -> None:
         # the max payout.
         # TODO: This will not work as the distortion is calculated
         # by summing over the columns.
-        raise NotImplementedError("Function not implemented")
+
+        # set the q_table for the policy
+        self.config.policy.q_table = self.q_table
+        total_dist = env.total_average_current_distortion()
+        while stop_criterion.continue_itr(total_dist):
+
+            if stop_criterion.iteration_counter == 12:
+                print("Break...")
+
+            # use the policy to select an action
+            state_idx = env.get_aggregated_state(total_dist)
+            action_idx = self.config.policy.on_state(state_idx)
+            action = env.get_action(action_idx)
+            print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
+                                                                                  action.column_name + "-" + action.action_type.name))
+            env.step(action=action)
+            total_dist = env.total_average_current_distortion()

     def train(self, env: Env, **options) -> tuple:

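The loop added above only requires that stop_criterion exposes a continue_itr(total_distortion) method and an iteration_counter attribute. The commit does not show a criterion class, so the following is only a minimal sketch of something that would satisfy that interface; the class name, constructor parameters and threshold logic are illustrative assumptions, not part of this change.

class MaxIterationsOrDistortionCriterion:
    """Illustrative stop criterion: stop after max_itrs iterations
    or once the observed distortion reaches min_distortion."""

    def __init__(self, max_itrs: int, min_distortion: float) -> None:
        self.max_itrs = max_itrs
        self.min_distortion = min_distortion
        self.iteration_counter = 0

    def continue_itr(self, total_distortion: float) -> bool:
        # stop when either the iteration budget or the distortion target is reached
        if self.iteration_counter >= self.max_itrs or total_distortion >= self.min_distortion:
            return False
        self.iteration_counter += 1
        return True

With such a criterion, a call would look roughly like agent.play(env, MaxIterationsOrDistortionCriterion(max_itrs=100, min_distortion=0.8)); again, these argument values are only an example.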
@@ -84,15 +102,10 @@ def train(self, env: Env, **options) -> tuple:
         for itr in range(self.config.n_itrs_per_episode):

             # epsilon-greedy action selection
-            action_idx = self.config.policy(q_func=self.q_table, state=state)
+            action_idx = self.config.policy(q_table=self.q_table, state=state)

             action = env.get_action(action_idx)

-            #if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
-            #    print("Attempt to generalize salary")
-            #else:
-            #    print(action.action_type.name, " on ", action.column_name)
-
             # take action A, observe R, S'
             next_time_step = env.step(action)
             next_state = next_time_step.observation
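Renaming the keyword argument from q_func to q_table assumes the policy object is callable with that signature during training and also offers the on_state() method used in play() above. Below is a rough, self-contained sketch of an epsilon-greedy policy compatible with both call sites; it is not the project's actual policy class, and the class name, the eps handling and the numpy-style q_table[state, action] indexing are assumptions.

import random

class EpsilonGreedyPolicySketch:
    """Illustrative epsilon-greedy policy; not the class defined in src.utils."""

    def __init__(self, n_actions: int, eps: float = 0.1) -> None:
        self.n_actions = n_actions
        self.eps = eps
        self.q_table = None  # set externally, e.g. by play()

    def __call__(self, q_table, state: int) -> int:
        # epsilon-greedy selection used during training
        if random.random() < self.eps:
            return random.randint(0, self.n_actions - 1)
        return max(range(self.n_actions), key=lambda a: q_table[state, a])

    def on_state(self, state: int) -> int:
        # greedy selection used during play, reading the table assigned to self.q_table
        return max(range(self.n_actions), key=lambda a: self.q_table[state, a])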
@@ -111,7 +124,8 @@ def train(self, env: Env, **options) -> tuple:

         return episode_score, total_distortion, counter

-    def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
+    def _update_Q_table(self, state: int, action: int, n_actions: int,
+                        reward: float, next_state: int = None) -> None:
         """
         Update the Q-value for the state
         """
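The body of _update_Q_table is not shown in this diff. For reference only, the textbook tabular Q-learning update that a method with this signature would typically perform is sketched below as a free function; the alpha and gamma hyperparameters and the q_table[state, action] indexing are assumptions, not taken from the source.

def q_learning_update(q_table, state: int, action: int, n_actions: int,
                      reward: float, next_state: int = None,
                      alpha: float = 0.1, gamma: float = 0.99) -> None:
    # temporal-difference target: r + gamma * max_a' Q(s', a'), or just r on the terminal step
    q_next = max(q_table[next_state, a] for a in range(n_actions)) if next_state is not None else 0.0
    # Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a))
    q_table[state, action] += alpha * (reward + gamma * q_next - q_table[state, action])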