"""
Simple Q-learning algorithm
"""

import numpy as np
from typing import TypeVar

from src.exceptions.exceptions import InvalidParamValue
from src.utils.mixins import WithMaxActionMixin

Env = TypeVar('Env')
Policy = TypeVar('Policy')
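# Env and Policy are duck-typed placeholders: as used below, the environment is
# expected to provide reset(), step(), get_action(), observation_space and
# action_space, and the policy to be callable as policy(q_func=..., state=...).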

class QLearnConfig(object):
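    """Configuration parameters for the Q-learning agent."""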

    def __init__(self):
        self.gamma: float = 1.0
        self.alpha: float = 0.1
        self.n_itrs_per_episode: int = 100
        self.policy: Policy = None


class QLearning(WithMaxActionMixin):
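    """Tabular Q-learning agent (off-policy TD control)."""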

    def __init__(self, algo_config: QLearnConfig):
        super(QLearning, self).__init__()
        self.q_table = {}
        self.config = algo_config

        # monitor performance
        self.total_rewards: np.ndarray = None
        self.iterations_per_episode = []

    @property
    def name(self) -> str:
        return "QLearn"

    def actions_before_training(self, env: Env, **options):
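        """
        Check that a policy has been set and initialize the Q-table
        entries to zero for every (state, action) pair
        :param env: the environment to train on
        :param options:
        :return:
        """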

        if self.config.policy is None:
            raise InvalidParamValue(param_name="policy", param_value="None")

        for state in range(env.observation_space.n):
            for action in range(env.action_space.n):
                self.q_table[state, action] = 0.0

    def actions_after_episode_ends(self, **options):
        """
        Execute any actions the algorithm needs after
        the episode ends
        :param options:
        :return:
        """

        self.config.policy.actions_after_episode(options['episode_idx'])

    def train(self, env: Env, **options) -> tuple:
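        """
        Train the agent on the given environment for a single episode
        and return the episode score and the number of steps taken
        :param env: the environment to train on
        :param options:
        :return: tuple of (episode_score, counter)
        """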

        # episode score
        episode_score = 0  # initialize score
        counter = 0

        time_step = env.reset()
        state = time_step.observation

        for itr in range(self.config.n_itrs_per_episode):

            # epsilon-greedy action selection
            action_idx = self.config.policy(q_func=self.q_table, state=state)

            action = env.get_action(action_idx)

            # take action A, observe R, S'
            next_time_step = env.step(action)
            next_state = next_time_step.observation
            reward = next_time_step.reward

            next_state_id = next_state.state_id if next_state is not None else None

            # add reward to agent's score
            episode_score += next_time_step.reward
            self._update_Q_table(state=state.state_id, action=action_idx, reward=reward,
                                 next_state=next_state_id, n_actions=env.action_space.n)
            state = next_state  # S <- S'
            counter += 1

            if next_time_step.last():
                break

        return episode_score, counter

    def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
        """
        Update the Q-value for the given state-action pair
        """
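        # Q-learning update implemented below:
        # Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))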

        # estimate in Q-table (for current state, action pair)
        q_s = self.q_table[state, action]

        # value of next state
        Qsa_next = (self.q_table[next_state, self.max_action(next_state, n_actions=n_actions)]
                    if next_state is not None else 0)

        # construct TD target
        target = reward + (self.config.gamma * Qsa_next)

        # get updated value
        new_value = q_s + (self.config.alpha * (target - q_s))
        self.q_table[state, action] = new_value
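

# Illustrative usage sketch, not part of the library: the toy corridor
# environment and epsilon-greedy policy below are hypothetical stand-ins that
# satisfy the duck-typed interface the trainer expects (reset/step/get_action,
# time steps with .observation/.reward/.last(), states with .state_id, and a
# policy callable as policy(q_func=..., state=...)). A real application would
# plug in the project's own environment wrappers and policies instead.
if __name__ == '__main__':
    import random
    from collections import namedtuple

    _Space = namedtuple('_Space', ['n'])
    _State = namedtuple('_State', ['state_id'])

    class _TimeStep(namedtuple('_TimeStep', ['observation', 'reward', 'done'])):
        def last(self):
            return self.done

    class _CorridorEnv(object):
        """Toy corridor: start at cell 0, reach cell n-1, reward -1 per step."""

        def __init__(self, n_states: int = 5):
            self.observation_space = _Space(n=n_states)
            self.action_space = _Space(n=2)  # 0: left, 1: right
            self._pos = 0

        def reset(self):
            self._pos = 0
            return _TimeStep(observation=_State(state_id=self._pos), reward=0.0, done=False)

        def get_action(self, action_idx):
            return action_idx

        def step(self, action):
            self._pos = max(0, self._pos + (1 if action == 1 else -1))
            done = self._pos == self.observation_space.n - 1
            obs = None if done else _State(state_id=self._pos)
            return _TimeStep(observation=obs, reward=-1.0, done=done)

    class _EpsGreedyPolicy(object):
        """Toy epsilon-greedy policy over the tabular q_func."""

        def __init__(self, eps: float, n_actions: int = 2):
            self.eps = eps
            self.n_actions = n_actions

        def __call__(self, q_func, state):
            if random.random() < self.eps:
                return random.randrange(self.n_actions)
            values = [q_func[state.state_id, a] for a in range(self.n_actions)]
            return int(np.argmax(values))

        def actions_after_episode(self, episode_idx):
            # decay exploration as training progresses
            self.eps = max(0.01, self.eps * 0.99)

    config = QLearnConfig()
    config.gamma = 0.99
    config.alpha = 0.1
    config.n_itrs_per_episode = 50
    config.policy = _EpsGreedyPolicy(eps=0.5)

    env = _CorridorEnv(n_states=5)
    agent = QLearning(algo_config=config)
    agent.actions_before_training(env)

    scores = []
    for episode_idx in range(200):
        score, n_steps = agent.train(env)
        scores.append(score)
        agent.actions_after_episode_ends(episode_idx=episode_idx)

    agent.total_rewards = np.array(scores)
    print("average score over the last 10 episodes:", agent.total_rewards[-10:].mean())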