11"""Module epsilon_greedy_q_estimator
22
33"""
from dataclasses import dataclass
from typing import TypeVar

import numpy as np

from src.utils.mixins import WithEstimatorMixin
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig

StateActionVec = TypeVar('StateActionVec')
State = TypeVar('State')
Action = TypeVar('Action')
Env = TypeVar('Env')


@dataclass(init=True, repr=True)
class EpsilonGreedyQEstimatorConfig(EpsilonGreedyConfig):
    """Configuration for EpsilonGreedyQEstimator"""

    gamma: float = 1.0
    alpha: float = 1.0
    env: Env = None


class EpsilonGreedyQEstimator(WithEstimatorMixin):
25+ """Q-function estimator using an epsilon-greedy policy
26+ for action selection
27+ """
28+
    def __init__(self, config: EpsilonGreedyQEstimatorConfig):
        """Constructor

        Parameters
        ----------
        config: The instance configuration

        """
        super().__init__()
        self.eps_policy: EpsilonGreedyPolicy = EpsilonGreedyPolicy.from_config(config)
        self.alpha: float = config.alpha
        self.gamma: float = config.gamma
        self.env: Env = config.env
        self.weights: np.ndarray = None

    def q_hat_value(self, state_action_vec: StateActionVec) -> float:
        r"""Returns the :math:`\hat{q}` approximate value
        for the given state-action vector

        Parameters
        ----------
        state_action_vec: The state-action feature vector

        Returns
        -------
        float

        """
        return self.weights.dot(state_action_vec)

    def update_weights(self, total_reward: float, state_action: StateActionVec,
                       state_action_: StateActionVec, t: float) -> None:
        """Update the weights

        Parameters
        ----------
        total_reward: The reward observed
        state_action: The state-action feature vector that led to the reward
        state_action_: The state-action feature vector of the successor pair
        t: The decay factor for alpha

        Returns
        -------
        None

        """
        v1 = self.q_hat_value(state_action_vec=state_action)
        v2 = self.q_hat_value(state_action_vec=state_action_)
        self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action

    def on_state(self, state: State) -> Action:
        """Returns the action on the given state

        Parameters
        ----------
        state: The state to act on

        Returns
        -------
        The action selected by the epsilon-greedy policy

        """

        # compute the state-action values related to
        # the given state
        q_values = []

        for action in range(self.env.n_actions):
            state_action_vector = self.env.get_state_action_tile(action=action, state=state)
            q_values.append(self.q_hat_value(state_action_vec=state_action_vector))

        # choose an action at the current state
        action = self.eps_policy(q_values, state)
        return action
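

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API. The dummy
    # environment below is purely illustrative: it only provides the
    # `n_actions` attribute and `get_state_action_tile` method that the
    # estimator relies on. It is also assumed that any epsilon-greedy
    # parameters inherited from EpsilonGreedyConfig have usable defaults.

    class _DummyEnv:
        """Toy environment with one-hot state-action features"""

        n_actions = 2
        n_states = 3

        def get_state_action_tile(self, action: int, state: int) -> np.ndarray:
            # one-hot encoding of the (state, action) pair
            vec = np.zeros(self.n_states * self.n_actions)
            vec[state * self.n_actions + action] = 1.0
            return vec

    env = _DummyEnv()
    config = EpsilonGreedyQEstimatorConfig(gamma=0.99, alpha=0.1, env=env)
    estimator = EpsilonGreedyQEstimator(config=config)
    estimator.weights = np.zeros(env.n_states * env.n_actions)

    # one update for the transition (s=0, a=0) -> (s=1, a=1) with reward 1.0
    x = env.get_state_action_tile(action=0, state=0)
    x_next = env.get_state_action_tile(action=1, state=1)
    estimator.update_weights(total_reward=1.0, state_action=x,
                             state_action_=x_next, t=1.0)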