Skip to content

Commit 4179668

Browse files
committed
Make progress with A2C algorithm
1 parent b1039e3 commit 4179668

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

src/algorithms/a2c.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66

7+
from src.utils.experience_buffer import unpack_batch
8+
79
Env = TypeVar("Env")
810
Optimizer = TypeVar("Optimizer")
911
LossFunction = TypeVar("LossFunction")
@@ -53,6 +55,8 @@ def __init__(self):
5355
self.n_iterations_per_episode: int = 100
5456
self.optimizer: Optimizer = None
5557
self.loss_function: LossFunction = None
58+
self.batch_size: int = 0
59+
self.device: str = 'cpu'
5660

5761

5862
class A2C(Generic[Optimizer]):
@@ -63,10 +67,13 @@ def __init__(self, config: A2CConfig, a2c_net: A2CNet):
6367
self.tau = config.tau
6468
self.n_workers = config.n_workers
6569
self.n_iterations_per_episode = config.n_iterations_per_episode
70+
self.batch_size = config.batch_size
6671
self.optimizer = config.optimizer
72+
self.device = config.device
6773
self.loss_function = config.loss_function
6874
self.a2c_net = a2c_net
6975
self.rewards = []
76+
self.memory = []
7077
self.name = "A2C"
7178

7279
def _optimize_model(self):
@@ -81,7 +88,17 @@ def select_action(self, env: Env, observation: State) -> Action:
8188
"""
8289
return env.sample_action()
8390

84-
def update(self):
91+
def update_policy_network(self):
92+
"""
93+
Update the policy network
94+
:return:
95+
"""
96+
pass
97+
98+
def calculate_loss(self):
99+
pass
100+
101+
def accummulate_batch(self):
85102
pass
86103

87104
def train(self, env: Env) -> None:
@@ -92,6 +109,9 @@ def train(self, env: Env) -> None:
92109

93110
observation = time_step.observation
94111

112+
# the batch to process
113+
batch = []
114+
95115
# learn over the episode
96116
for iteration in range(1, self.n_iterations_per_episode + 1):
97117

@@ -102,11 +122,27 @@ def train(self, env: Env) -> None:
102122
# to the selected action
103123
next_time_step = env.step(action=action)
104124

125+
batch.append(next_time_step.observation)
126+
127+
if len(batch) < self.batch_size:
128+
continue
129+
130+
# unpack the batch in order to process it
131+
states_v, actions_t, vals_ref = unpack_batch(batch=batch, net=self.a2c_net, device=self.device)
132+
batch.clear()
133+
134+
self.optimizer.zero_grad()
105135
# we reached the end of the episode
106-
if next_time_step.last():
107-
break
136+
#if next_time_step.last():
137+
# break
138+
139+
#next_state = next_time_step.observation
140+
policy_val, v_val = self.a2c_net.forward(x=states_v)
141+
142+
self.optimizer.zero_grad()
108143

109-
next_state = next_time_step.observation
110-
policy_val, v_val = self.a2c_net.forward(x=next_state)
111-
self._optimize_model()
144+
# calculate loss
145+
loss = self.calculate_loss()
146+
loss.backward()
147+
self.optimizer.step()
112148

src/utils/serial_hierarchy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class SerialHierarchy(HierarchyBase):
4343
that are applied one after the other. Applications should explicitly
4444
provide the list of the ensuing transformations. For example assume that the
4545
data field has the value 'foo'; then the values are given by
46-
the following list ['fo*', 'f**', '***']
46+
the following list ['fo*', 'f**', '***']
4747
"""
4848
def __init__(self, values: List) -> None:
4949
"""

0 commit comments

Comments
 (0)