Skip to content

Commit 2d77f0d

Browse files
authored
Merge pull request #42 from pockerman/add_sarsa_semi_gradient
Refactorings
2 parents ae83c32 + a318ec3 commit 2d77f0d

File tree

5 files changed

+80
-56
lines changed

5 files changed

+80
-56
lines changed

src/algorithms/q_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
8989
env.step(action=action)
9090
total_dist = env.total_current_distortion()
9191

92-
def train(self, env: Env, **options) -> tuple:
92+
def on_episode(self, env: Env, **options) -> tuple:
9393

9494
# episode score
9595
episode_score = 0

src/algorithms/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def train(self):
7474
ignore = self.env.reset()
7575

7676
# train for a number of iterations
77-
episode_score, total_distortion, n_itrs = self.agent.train(self.env)
77+
episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
7878

7979
print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs))
8080

src/spaces/column_type.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Simple enumeration of column types.
3+
This is similar to the attribute types used by the ARX software.
4+
See the ARX documentation at:
5+
https://arx.deidentifier.org/wp-content/uploads/javadoc/current/api/org/deidentifier/arx/AttributeType.html
6+
"""
7+
8+
import enum
9+
10+
11+
@enum.unique
class ColumnType(enum.IntEnum):
    """
    Enumeration of the column types.

    The member names follow the attribute types listed in the ARX
    ``AttributeType`` documentation, plus ``INVALID_TYPE``.

    ``enum.unique`` makes the class definition fail if two members are
    ever given the same integer value, so no member can silently become
    an alias of another.
    """

    INVALID_TYPE = 0
    IDENTIFYING_ATTRIBUTE = 1
    SENSITIVE_ATTRIBUTE = 2
    INSENSITIVE_ATTRIBUTE = 3
    QUASI_IDENTIFYING_ATTRIBUTE = 4

src/spaces/discrete_state_environment.py

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,69 +4,19 @@
44
"""
55

66
import copy
7-
import enum
87
import numpy as np
98
from pathlib import Path
10-
import pandas as pd
11-
import torch
12-
from typing import NamedTuple, Generic, Optional, TypeVar, List
9+
from typing import TypeVar, List
1310
import multiprocessing as mp
1411

1512
from src.spaces.actions import ActionBase, ActionType
16-
from src.utils.string_distance_calculator import StringDistanceType, TextDistanceCalculator
17-
from src.utils.numeric_distance_type import NumericDistanceType
18-
from src.utils.numeric_distance_calculator import NumericDistanceCalculator
13+
from src.spaces.time_step import TimeStep, StepType
1914

2015
DataSet = TypeVar("DataSet")
2116
RewardManager = TypeVar("RewardManager")
2217
ActionSpace = TypeVar("ActionSpace")
2318
DistortionCalculator = TypeVar('DistortionCalculator')
2419

25-
_Reward = TypeVar('_Reward')
26-
_Discount = TypeVar('_Discount')
27-
_Observation = TypeVar('_Observation')
28-
29-
30-
class StepType(enum.IntEnum):
31-
"""
32-
Defines the status of a `TimeStep` within a sequence.
33-
"""
34-
35-
# Denotes the first `TimeStep` in a sequence.
36-
FIRST = 0
37-
38-
# Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
39-
MID = 1
40-
41-
# Denotes the last `TimeStep` in a sequence.
42-
LAST = 2
43-
44-
def first(self) -> bool:
45-
return self is StepType.FIRST
46-
47-
def mid(self) -> bool:
48-
return self is StepType.MID
49-
50-
def last(self) -> bool:
51-
return self is StepType.LAST
52-
53-
54-
class TimeStep(NamedTuple, Generic[_Reward, _Discount, _Observation]):
55-
step_type: StepType
56-
info: dict
57-
reward: Optional[_Reward]
58-
discount: Optional[_Discount]
59-
observation: _Observation
60-
61-
def first(self) -> bool:
62-
return self.step_type == StepType.FIRST
63-
64-
def mid(self) -> bool:
65-
return self.step_type == StepType.MID
66-
67-
def last(self) -> bool:
68-
return self.step_type == StepType.LAST
69-
7020

7121
class DiscreteEnvConfig(object):
7222
"""
@@ -79,8 +29,6 @@ def __init__(self) -> None:
7929
self.reward_manager: RewardManager = None
8030
self.average_distortion_constraint: float = 0.0
8131
self.gamma: float = 0.99
82-
# self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
83-
# self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
8432
self.n_states: int = 10
8533
self.min_distortion: float = 0.4
8634
self.max_distortion: float = 0.7
@@ -115,6 +63,10 @@ def __init__(self, env_config: DiscreteEnvConfig) -> None:
11563
self.column_visits = {}
11664
self.create_bins()
11765

66+
@property
67+
def columns_attribute_types(self) -> dict:
68+
return self.config.data_set.columns_attribute_types
69+
11870
@property
11971
def action_space(self):
12072
return self.config.action_space

src/spaces/time_step.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Defines the StepType enumeration and the TimeStep named tuple.
3+
"""
4+
5+
import enum
6+
from typing import NamedTuple, Generic, Optional, TypeVar
7+
8+
_Reward = TypeVar('_Reward')
9+
_Discount = TypeVar('_Discount')
10+
_Observation = TypeVar('_Observation')
11+
12+
13+
@enum.unique
class StepType(enum.IntEnum):
    """
    Defines the status of a `TimeStep` within a sequence.

    ``enum.unique`` makes the class definition fail if two members are
    ever given the same integer value; the identity checks in the
    predicates below rely on every member being distinct.
    """

    # Denotes the first `TimeStep` in a sequence.
    FIRST = 0

    # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
    MID = 1

    # Denotes the last `TimeStep` in a sequence.
    LAST = 2

    def first(self) -> bool:
        """Return True if this member is ``StepType.FIRST``."""
        return self is StepType.FIRST

    def mid(self) -> bool:
        """Return True if this member is ``StepType.MID``."""
        return self is StepType.MID

    def last(self) -> bool:
        """Return True if this member is ``StepType.LAST``."""
        return self is StepType.LAST
35+
36+
37+
class _TimeStepFields(NamedTuple):
    """
    Field layout of `TimeStep`.

    The fields are declared on a plain ``NamedTuple`` because combining
    ``NamedTuple`` and ``Generic`` in one class statement raises a
    ``TypeError`` on Python 3.9 and 3.10 and is only supported from
    Python 3.11 onwards. `TimeStep` adds the generic parameters by
    subclassing this class.
    """

    step_type: StepType
    info: dict
    reward: Optional[_Reward]
    discount: Optional[_Discount]
    observation: _Observation


class TimeStep(_TimeStepFields, Generic[_Reward, _Discount, _Observation]):
    """
    Named tuple holding ``step_type``, ``info``, ``reward``, ``discount``
    and ``observation``, in that order.

    The class is generic over the reward, discount and observation
    types, so it can be subscripted as ``TimeStep[R, D, O]``.
    """

    # Keep instances as plain tuples without a per-instance __dict__,
    # as a class created directly by NamedTuple would be.
    __slots__ = ()

    def first(self) -> bool:
        """Return True if ``step_type`` equals ``StepType.FIRST``."""
        return self.step_type == StepType.FIRST

    def mid(self) -> bool:
        """Return True if ``step_type`` equals ``StepType.MID``."""
        return self.step_type == StepType.MID

    def last(self) -> bool:
        """Return True if ``step_type`` equals ``StepType.LAST``."""
        return self.step_type == StepType.LAST

    @property
    def done(self) -> bool:
        """Property form of `last()`: True when ``step_type`` is ``LAST``."""
        return self.last()

0 commit comments

Comments
 (0)