"""Module epsilon_greedy_q_estimator. Implements
a Q-estimator that assumes linear function approximation.

"""
from typing import TypeVar
import numpy as np
from dataclasses import dataclass

from src.utils.mixins import WithEstimatorMixin
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig

StateActionVec = TypeVar('StateActionVec')
State = TypeVar('State')
Action = TypeVar('Action')
Env = TypeVar('Env')


@dataclass(init=True, repr=True)
class EpsilonGreedyQEstimatorConfig(EpsilonGreedyConfig):
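    """Configuration for EpsilonGreedyQEstimator.

    gamma is the discount factor, alpha the learning rate and
    env the environment instance the estimator interacts with.
    """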
    gamma: float = 1.0
    alpha: float = 1.0
    env: Env = None


class EpsilonGreedyQEstimator(WithEstimatorMixin):
    """Q-function estimator using an epsilon-greedy policy
    for action selection.
    """

    def __init__(self, config: EpsilonGreedyQEstimatorConfig):
        """Constructor

        Parameters
        ----------
        config: The instance configuration

        """
        super().__init__()
        self.eps_policy: EpsilonGreedyPolicy = EpsilonGreedyPolicy.from_config(config)
        self.alpha: float = config.alpha
        self.gamma: float = config.gamma
        self.env: Env = config.env
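        # weight vector of the linear approximator; it must be initialised
        # (e.g. to zeros with the length of the tiled state-action vector)
        # before q_hat_value or update_weights is called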
        self.weights: np.ndarray = None

    def q_hat_value(self, state_action_vec: StateActionVec) -> float:
        r"""Returns the :math:`\hat{q}` approximate value
        for the given state-action vector.

        Parameters
        ----------
        state_action_vec: The state-action tiled vector

        Returns
        -------
        float

        """
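        # linear function approximation: q_hat(s, a) = w . x(s, a),
        # where x(s, a) is the tiled state-action feature vector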
        return self.weights.dot(state_action_vec)

    def update_weights(self, total_reward: float, state_action: StateActionVec,
                       state_action_: StateActionVec, t: float) -> None:
        """Update the weights

        Parameters
        ----------
        total_reward: The reward observed
        state_action: The state-action vector that led to the reward
        state_action_: The state-action vector of the next step
        t: The decay factor for alpha

        Returns
        -------
        None

        """
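        # semi-gradient TD update of the linear approximator:
        # w <- w + (alpha / t) * [r + gamma * q_hat(s', a') - q_hat(s, a)] * x(s, a)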
        v1 = self.q_hat_value(state_action_vec=state_action)
        v2 = self.q_hat_value(state_action_vec=state_action_)
        self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action

    def on_state(self, state: State) -> Action:
        """Returns the action on the given state

        Parameters
        ----------
        state: The state observed

        Returns
        -------
        An environment specific Action type

        """

        # compute the Q-value of every action
        # available at the given state
        q_values = []

        for action in range(self.env.n_actions):
            state_action_vector = self.env.get_state_action_tile(action=action, state=state)
            q_values.append(self.q_hat_value(state_action_vec=state_action_vector))

        # let the epsilon-greedy policy choose an action at the current state
        action = self.eps_policy(q_values, state)
        return action
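

# Minimal usage sketch (illustration only; `env`, `state` and `n_features` are
# hypothetical). The environment is assumed to expose `n_actions` and
# `get_state_action_tile(action, state)` returning a NumPy feature vector, as
# `on_state` above requires; the config would also need whatever fields
# EpsilonGreedyConfig defines:
#
#   config = EpsilonGreedyQEstimatorConfig(gamma=0.99, alpha=0.1, env=env)
#   estimator = EpsilonGreedyQEstimator(config=config)
#   estimator.weights = np.zeros(n_features)
#   action = estimator.on_state(state)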