@@ -4,6 +4,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from src.utils.experience_buffer import unpack_batch
+
 Env = TypeVar("Env")
 Optimizer = TypeVar("Optimizer")
 LossFunction = TypeVar("LossFunction")
@@ -53,6 +55,8 @@ def __init__(self):
         self.n_iterations_per_episode: int = 100
         self.optimizer: Optimizer = None
         self.loss_function: LossFunction = None
+        self.batch_size: int = 0
+        self.device: str = 'cpu'
 
 
 class A2C(Generic[Optimizer]):
@@ -63,10 +67,13 @@ def __init__(self, config: A2CConfig, a2c_net: A2CNet):
         self.tau = config.tau
         self.n_workers = config.n_workers
         self.n_iterations_per_episode = config.n_iterations_per_episode
+        self.batch_size = config.batch_size
         self.optimizer = config.optimizer
+        self.device = config.device
         self.loss_function = config.loss_function
         self.a2c_net = a2c_net
         self.rewards = []
+        self.memory = []
         self.name = "A2C"
 
     def _optimize_model(self):
@@ -81,7 +88,17 @@ def select_action(self, env: Env, observation: State) -> Action:
         """
         return env.sample_action()
 
-    def update(self):
+    def update_policy_network(self):
+        """
+        Update the policy network
+        :return:
+        """
+        pass
+
+    def calculate_loss(self):
+        pass
+
+    def accummulate_batch(self):
         pass
 
     def train(self, env: Env) -> None:
@@ -92,6 +109,9 @@ def train(self, env: Env) -> None:
 
         observation = time_step.observation
 
+        # the batch to process
+        batch = []
+
         # learn over the episode
         for iteration in range(1, self.n_iterations_per_episode + 1):
 
@@ -102,11 +122,27 @@ def train(self, env: Env) -> None:
             # to the selected action
             next_time_step = env.step(action=action)
 
+            batch.append(next_time_step.observation)
+
+            if len(batch) < self.batch_size:
+                continue
+
+            # unpack the batch in order to process it
+            states_v, actions_t, vals_ref = unpack_batch(batch=batch, net=self.a2c_net, device=self.device)
+            batch.clear()
+
+            self.optimizer.zero_grad()
             # we reached the end of the episode
-            if next_time_step.last():
-                break
+            #if next_time_step.last():
+            #    break
+
+            #next_state = next_time_step.observation
+            policy_val, v_val = self.a2c_net.forward(x=states_v)
+
+            self.optimizer.zero_grad()
 
-            next_state = next_time_step.observation
-            policy_val, v_val = self.a2c_net.forward(x=next_state)
-            self._optimize_model()
+            # calculate loss
+            loss = self.calculate_loss()
+            loss.backward()
+            self.optimizer.step()
 
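The calculate_loss stub introduced above is still a pass, so the loss.backward() call at the end of the training hunk has nothing to differentiate yet. Below is a minimal sketch of a standard advantage actor-critic loss that such a method could compute; it assumes the network's forward pass returns (policy logits, state values), as the call to self.a2c_net.forward suggests, and that unpack_batch yields states_v, actions_t and bootstrapped value targets vals_ref. The helper name a2c_loss and the coefficient values are illustrative only, not part of this repository.

import torch
import torch.nn.functional as F

def a2c_loss(policy_logits: torch.Tensor, state_values: torch.Tensor,
             actions_t: torch.Tensor, vals_ref: torch.Tensor,
             value_loss_coeff: float = 0.5, entropy_coeff: float = 0.01) -> torch.Tensor:
    # critic term: mean squared error between predicted values and the bootstrapped targets
    value_loss = F.mse_loss(state_values.squeeze(-1), vals_ref)

    # actor term: log-probability of the actions taken, weighted by the advantage
    log_probs = F.log_softmax(policy_logits, dim=1)
    advantage = vals_ref - state_values.squeeze(-1).detach()
    log_prob_actions = log_probs[torch.arange(len(actions_t)), actions_t]
    policy_loss = -(advantage * log_prob_actions).mean()

    # entropy bonus to discourage premature collapse of the policy
    probs = F.softmax(policy_logits, dim=1)
    entropy = -(probs * log_probs).sum(dim=1).mean()

    return policy_loss + value_loss_coeff * value_loss - entropy_coeff * entropy

With a helper along these lines, calculate_loss could simply apply it to the tensors produced by unpack_batch before the loss.backward() and self.optimizer.step() calls shown in the hunk above.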