Skip to content

Commit 33b22fe

Browse files
authored
Merge pull request #10 from pockerman/add_actor_critic_algorithm
Towards actor critic algorithm
2 parents e6816d3 + a26442c commit 33b22fe

File tree

3 files changed

+64
-10
lines changed

3 files changed

+64
-10
lines changed

src/algorithms/a2c.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66

7+
from src.utils.experience_buffer import unpack_batch
8+
79
Env = TypeVar("Env")
810
Optimizer = TypeVar("Optimizer")
911
LossFunction = TypeVar("LossFunction")
@@ -53,6 +55,8 @@ def __init__(self):
5355
self.n_iterations_per_episode: int = 100
5456
self.optimizer: Optimizer = None
5557
self.loss_function: LossFunction = None
58+
self.batch_size: int = 0
59+
self.device: str = 'cpu'
5660

5761

5862
class A2C(Generic[Optimizer]):
@@ -63,15 +67,15 @@ def __init__(self, config: A2CConfig, a2c_net: A2CNet):
6367
self.tau = config.tau
6468
self.n_workers = config.n_workers
6569
self.n_iterations_per_episode = config.n_iterations_per_episode
70+
self.batch_size = config.batch_size
6671
self.optimizer = config.optimizer
72+
self.device = config.device
6773
self.loss_function = config.loss_function
6874
self.a2c_net = a2c_net
6975
self.rewards = []
76+
self.memory = []
7077
self.name = "A2C"
7178

72-
def _optimize_model(self):
73-
pass
74-
7579
def select_action(self, env: Env, observation: State) -> Action:
7680
"""
7781
Select an action
@@ -81,17 +85,43 @@ def select_action(self, env: Env, observation: State) -> Action:
8185
"""
8286
return env.sample_action()
8387

84-
def update(self):
88+
def update_policy_network(self):
    """Apply the accumulated gradients to the policy network.

    Placeholder: no update logic has been implemented yet, so the
    call is a no-op and returns None.

    :return: None
    """
    # TODO: implement the policy-network update step
    pass
94+
95+
def calculate_loss(self):
    """Compute the actor-critic loss for the current batch.

    Placeholder: the loss computation has not been implemented yet,
    so the call is a no-op and returns None.

    :return: None
    """
    # TODO: implement the actor/critic loss calculation
    pass
101+
102+
def accummulate_batch(self):
    """Accumulate experience items into the memory buffer.

    Placeholder: accumulation has not been implemented yet, so the
    call is a no-op and returns None.

    NOTE(review): the name is misspelled ("accummulate" instead of
    "accumulate"); renaming would change the public interface, so it
    is kept as-is here.

    :return: None
    """
    # TODO: gather experience items into self.memory
    pass
86108

87109
def train(self, env: Env) -> None:
110+
"""
111+
Train the agent on the given environment
112+
:param env:
113+
:return:
114+
"""
88115

89116
# reset the environment and obtain the
90117
# the time step
91118
time_step: TimeStep = env.reset()
92119

93120
observation = time_step.observation
94121

122+
# the batch to process
123+
batch = []
124+
95125
# learn over the episode
96126
for iteration in range(1, self.n_iterations_per_episode + 1):
97127

@@ -102,11 +132,27 @@ def train(self, env: Env) -> None:
102132
# to the selected action
103133
next_time_step = env.step(action=action)
104134

135+
batch.append(next_time_step.observation)
136+
137+
if len(batch) < self.batch_size:
138+
continue
139+
140+
# unpack the batch in order to process it
141+
states_v, actions_t, vals_ref = unpack_batch(batch=batch, net=self.a2c_net, device=self.device)
142+
batch.clear()
143+
144+
self.optimizer.zero_grad()
105145
# we reached the end of the episode
106-
if next_time_step.last():
107-
break
146+
#if next_time_step.last():
147+
# break
148+
149+
#next_state = next_time_step.observation
150+
policy_val, v_val = self.a2c_net.forward(x=states_v)
151+
152+
self.optimizer.zero_grad()
108153

109-
next_state = next_time_step.observation
110-
policy_val, v_val = self.a2c_net.forward(x=next_state)
111-
self._optimize_model()
154+
# calculate loss
155+
loss = self.calculate_loss()
156+
loss.backward()
157+
self.optimizer.step()
112158

src/utils/experience_buffer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
from typing import TypeVar

# Generic placeholders for the network and batch types; concrete types
# are supplied by the caller (e.g. a torch.nn.Module and a list of
# collected observations).
Net = TypeVar('Net')
Batch = TypeVar('Batch')


def unpack_batch(batch: Batch, net: Net, device: str = 'cpu'):
    """
    Unpack the given experience batch into training inputs.

    :param batch: the collected experience items to unpack
    :param net: the network used to evaluate the batch items
    :param device: the device the resulting data should live on
    :return: the unpacked batch components
    :raises NotImplementedError: the function is not implemented yet
    """
    # Fail loudly instead of silently returning None: callers unpack
    # multiple values from the result, so a bare ``pass`` would surface
    # as a confusing TypeError at the call site rather than pointing at
    # the missing implementation here.
    raise NotImplementedError("unpack_batch is not implemented yet")

src/utils/serial_hierarchy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class SerialHierarchy(HierarchyBase):
4343
that are applied one after the other. Applications should explicitly
4444
provide the list of the ensuing transformations. For example assume that the
4545
data field has the value 'foo' then values
46-
the following list ['fo*', 'f**', '***']
46+
the following list ['fo*', 'f**', '***']
4747
"""
4848
def __init__(self, values: List) -> None:
4949
"""

0 commit comments

Comments
 (0)