Commit 93362cc

Merge pull request #61 from pockerman/investigate_sarsa_semi_gradient
Investigate sarsa semi gradient
2 parents 225974a + 5f24dda commit 93362cc

35 files changed: +11213 / -178 lines

.github/workflows/python-app.yml

Lines changed: 4 additions & 3 deletions

@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a single version of Python
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 
-name: Python application
+name: Data-Anonymity-RL
 
 on:
   push:
@@ -33,8 +33,9 @@ jobs:
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with unittest
       run: |
-        pytest
+        pytest tests/test_suite.py -v --junitxml="test_result.xml"
     - name: Upload Unit Test Results
+      uses: EnricoMi/publish-unit-test-result-action@v1
      if: always()
      with:
-        files: test-results/**/*.xml
+        files: test_result.xml

README.md

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,6 @@
+
+[![Data-Anonymity-RL](https://github.com/pockerman/rl_anonymity_with_python/actions/workflows/python-app.yml/badge.svg)](https://github.com/pockerman/rl_anonymity_with_python/actions/workflows/python-app.yml)
+
 # RL anonymity (with Python)
 
 An experimental effort to use reinforcement learning techniques for data anonymization.

docs/source/API/actions.rst

Lines changed: 1 addition & 9 deletions

@@ -3,16 +3,8 @@
 
 .. automodule:: actions
 
-
-
-
 
-
-
-
-
-
-
+
 .. rubric:: Classes
 
 .. autosummary::

docs/source/API/epsilon_greedy_q_estimator.rst

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+epsilon\_greedy\_q\_estimator
+=============================
+
+.. automodule:: epsilon_greedy_q_estimator
+
+.. autoclass:: EpsilonGreedyQEstimatorConfig
+
+.. autoclass:: EpsilonGreedyQEstimator
+   :members: __init__, q_hat_value, update_weights, on_state

docs/source/API/state.rst

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+state
+=====
+
+.. automodule:: state
+
+.. autoclass:: StateIterator
+   :members: __init__, at, finished, __next__, __len__
+
+.. autoclass:: State
+   :members: __init__, __contains__, __iter__, __getitem__

docs/source/modules.rst

Lines changed: 2 additions & 2 deletions

@@ -5,14 +5,14 @@ API
    :maxdepth: 4
 
    API/actions
+   API/state
+   API/epsilon_greedy_q_estimator
    generated/action_space
    generated/q_estimator
    generated/q_learning
    generated/trainer
-   generated/sarsa_semi_gradient
    generated/exceptions
    generated/action_space
-   generated/actions
    generated/column_type
    generated/discrete_state_environment
    generated/observation_space

epsilon_greedy_q_estimator.py

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+"""Module epsilon_greedy_q_estimator. Implements
+a q-estimator by assuming linear function approximation
+
+"""
+from typing import TypeVar
+import numpy as np
+from dataclasses import dataclass
+
+from src.utils.mixins import WithEstimatorMixin
+from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig
+
+StateActionVec = TypeVar('StateActionVec')
+State = TypeVar('State')
+Action = TypeVar('Action')
+Env = TypeVar('Env')
+
+
+@dataclass(init=True, repr=True)
+class EpsilonGreedyQEstimatorConfig(EpsilonGreedyConfig):
+    gamma: float = 1.0
+    alpha: float = 1.0
+    env: Env = None
+
+
+class EpsilonGreedyQEstimator(WithEstimatorMixin):
+    """Q-function estimator using an epsilon-greedy policy
+    for action selection
+    """
+
+    def __init__(self, config: EpsilonGreedyQEstimatorConfig):
+        """Constructor
+
+        Parameters
+        ----------
+
+        config: The instance configuration
+
+        """
+        super(EpsilonGreedyQEstimator, self).__init__()
+        self.eps_policy: EpsilonGreedyPolicy = EpsilonGreedyPolicy.from_config(config)
+        self.alpha: float = config.alpha
+        self.gamma: float = config.gamma
+        self.env: Env = config.env
+        self.weights: np.array = None
+
+    def q_hat_value(self, state_action_vec: StateActionVec) -> float:
+        r"""Returns the
+        :math:`\hat{q}`
+
+        approximate value for the given state-action vector
+
+        Parameters
+        ----------
+
+        state_action_vec: The state-action tiled vector
+
+        Returns
+        -------
+        float
+
+
+        """
+        return self.weights.dot(state_action_vec)
+
+    def update_weights(self, total_reward: float, state_action: Action,
+                       state_action_: Action, t: float) -> None:
+        """
+        Update the weights
+
+        Parameters
+        ----------
+
+        total_reward: The reward observed
+        state_action: The current state-action vector
+        state_action_: The state-action vector at the next step
+        t: The decay factor for alpha
+
+        Returns
+        -------
+
+        None
+
+        """
+        v1 = self.q_hat_value(state_action_vec=state_action)
+        v2 = self.q_hat_value(state_action_vec=state_action_)
+        self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
+
+    def on_state(self, state: State) -> Action:
+        """Returns the action on the given state
+
+        Parameters
+        ----------
+
+        state: The state observed
+
+        Returns
+        -------
+
+        An environment specific Action type
+        """
+
+        # collect the state-action tile vectors
+        # for every action at the given state
+        q_values = []
+
+        for action in range(self.env.n_actions):
+            state_action_vector = self.env.get_state_action_tile(action=action, state=state)
+            q_values.append(state_action_vector)
+
+        # choose an action at the current state
+        action = self.eps_policy(q_values, state)
+        return action
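
The update in update_weights above is the standard linear semi-gradient rule. As a quick illustration only, here is the same arithmetic in plain numpy with made-up values (alpha, gamma, t, the return G and the tile vectors x, x_next are all invented for the sketch; this is not repository code):

import numpy as np

# Semi-gradient update sketch: w <- w + (alpha / t) * (G + gamma * w.x' - w.x) * x
alpha, gamma, t = 0.1, 0.99, 1.0
w = np.zeros(4)                           # weight vector, one entry per tile feature
x = np.array([1.0, 0.0, 1.0, 0.0])        # current state-action tile vector
x_next = np.array([0.0, 1.0, 0.0, 1.0])   # next state-action tile vector
G = 2.0                                   # observed reward / return

td_error = G + gamma * w.dot(x_next) - w.dot(x)
w += alpha / t * td_error * x
print(w)   # [0.2 0.  0.2 0. ]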

src/algorithms/sarsa_semi_gradient.py renamed to src/algorithms/n_step_semi_gradient_sarsa.py

Lines changed: 35 additions & 4 deletions

@@ -21,7 +21,6 @@
 @dataclass(init=True, repr=True)
 class SARSAnConfig:
     """Configuration class for n-step SARSA algorithm
-
     """
     gamma: float = 1.0
     alpha: float = 0.1
@@ -39,13 +38,45 @@ class SARSAn(WithMaxActionMixin):
     """
 
     def __init__(self, sarsa_config: SARSAnConfig):
-        super(SARSAn, self).__init__()
+        super(SARSAn, self).__init__(table={})
         self.name = "SARSAn"
         self.config = sarsa_config
-        self.q_table = {}
 
     def play(self, env: Env, stop_criterion: Criterion) -> None:
-        pass
+        """
+        Apply the trained agent on the given environment.
+
+        Parameters
+        ----------
+        env: The environment to apply the agent
+        stop_criterion: Criteria that specify when play should stop
+
+        Returns
+        -------
+
+        None
+
+        """
+        # loop over the columns and for each
+        # column get the action that corresponds to
+        # the max payout.
+        # TODO: This will not work as the distortion is calculated
+        # by summing over the columns.
+
+        # set the q_table for the policy;
+        # this is the table we should be using to
+        # make decisions
+        self.config.policy.q_table = self.q_table
+        total_dist = env.total_current_distortion()
+        while stop_criterion.continue_itr(total_dist):
+            # use the policy to select an action
+            state_idx = env.get_aggregated_state(total_dist)
+            action_idx = self.config.policy.on_state(state_idx)
+            action = env.get_action(action_idx)
+            print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
+                                                                                  action.column_name + "-" + action.action_type.name))
+            env.step(action=action)
+            total_dist = env.total_current_distortion()
 
     def actions_before_training(self, env: Env) -> None:
         """

src/algorithms/q_learning.py

Lines changed: 0 additions & 3 deletions

@@ -77,9 +77,6 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
         total_dist = env.total_current_distortion()
         while stop_criterion.continue_itr(total_dist):
 
-            if stop_criterion.iteration_counter == 12:
-                print("Break...")
-
             # use the policy to select an action
             state_idx = env.get_aggregated_state(total_dist)
             action_idx = self.config.policy.on_state(state_idx)
