 import torch.nn as nn
 import torch.nn.functional as F
 
+from src.utils.experience_buffer import unpack_batch
+
 Env = TypeVar("Env")
 Optimizer = TypeVar("Optimizer")
 LossFunction = TypeVar("LossFunction")
@@ -53,6 +55,8 @@ def __init__(self):
         self.n_iterations_per_episode: int = 100
         self.optimizer: Optimizer = None
         self.loss_function: LossFunction = None
+        self.batch_size: int = 0
+        self.device: str = 'cpu'
 
 
 class A2C(Generic[Optimizer]):
@@ -63,15 +67,15 @@ def __init__(self, config: A2CConfig, a2c_net: A2CNet):
         self.tau = config.tau
         self.n_workers = config.n_workers
         self.n_iterations_per_episode = config.n_iterations_per_episode
+        self.batch_size = config.batch_size
         self.optimizer = config.optimizer
+        self.device = config.device
         self.loss_function = config.loss_function
         self.a2c_net = a2c_net
         self.rewards = []
+        self.memory = []
         self.name = "A2C"
 
-    def _optimize_model(self):
-        pass
-
     def select_action(self, env: Env, observation: State) -> Action:
         """
         Select an action
@@ -81,17 +85,43 @@ def select_action(self, env: Env, observation: State) -> Action:
         """
         return env.sample_action()
 
-    def update(self):
+    def update_policy_network(self):
+        """
+        Update the policy network
+        :return:
+        """
+        pass
+
+    def calculate_loss(self):
+        """
+        Calculate the loss
+        :return:
+        """
+        pass
+
+    def accummulate_batch(self):
+        """
+        Accumulate the memory items
+        :return:
+        """
         pass
 
     def train(self, env: Env) -> None:
+        """
+        Train the agent on the given environment
+        :param env:
+        :return:
+        """
 
         # reset the environment and obtain the
         # initial time step
         time_step: TimeStep = env.reset()
 
         observation = time_step.observation
 
+        # the batch to process
+        batch = []
+
         # learn over the episode
         for iteration in range(1, self.n_iterations_per_episode + 1):
 
@@ -102,11 +132,27 @@ def train(self, env: Env) -> None:
             # to the selected action
             next_time_step = env.step(action=action)
 
+            batch.append(next_time_step.observation)
+
+            if len(batch) < self.batch_size:
+                continue
+
+            # unpack the batch in order to process it
+            states_v, actions_t, vals_ref = unpack_batch(batch=batch, net=self.a2c_net, device=self.device)
+            batch.clear()
+
+            self.optimizer.zero_grad()
             # we reached the end of the episode
-            if next_time_step.last():
-                break
+            # if next_time_step.last():
+            #     break
+
+            # next_state = next_time_step.observation
+            policy_val, v_val = self.a2c_net.forward(x=states_v)
 
-            next_state = next_time_step.observation
-            policy_val, v_val = self.a2c_net.forward(x=next_state)
-            self._optimize_model()
+            # calculate the loss
+            loss = self.calculate_loss()
+            loss.backward()
+            self.optimizer.step()
 
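
The helper unpack_batch is imported from src.utils.experience_buffer but is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming each buffered item exposes state, action, reward and last_state attributes and that the critic head of net is used to bootstrap non-terminal returns. The attribute names and the gamma discount are illustrative assumptions, not code from the repository; note that the loop above currently appends raw observations, so the real item layout may differ.

import numpy as np
import torch


def unpack_batch(batch, net, device="cpu", gamma=0.99):
    # Illustrative sketch: convert a list of experience items into training tensors.
    # Assumes item.state, item.action, item.reward and item.last_state
    # (None when the episode terminated inside the item).
    states, actions, rewards = [], [], []
    not_done_idx, last_states = [], []
    for idx, exp in enumerate(batch):
        states.append(np.asarray(exp.state))
        actions.append(int(exp.action))
        rewards.append(exp.reward)
        if exp.last_state is not None:
            not_done_idx.append(idx)
            last_states.append(np.asarray(exp.last_state))

    states_v = torch.FloatTensor(np.asarray(states)).to(device)
    actions_t = torch.LongTensor(actions).to(device)

    # bootstrap the return of non-terminal items with the critic's value estimate
    rewards_np = np.asarray(rewards, dtype=np.float32)
    if not_done_idx:
        last_states_v = torch.FloatTensor(np.asarray(last_states)).to(device)
        _, last_vals_v = net(last_states_v)
        rewards_np[not_done_idx] += gamma * last_vals_v.squeeze(-1).data.cpu().numpy()

    vals_ref_v = torch.FloatTensor(rewards_np).to(device)
    return states_v, actions_t, vals_ref_v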
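
calculate_loss is left as a stub in this commit. For advantage actor-critic the loss typically combines a policy-gradient term weighted by the advantage, a value-regression term and an entropy bonus. The sketch below follows that standard recipe using the policy_val, v_val and unpack_batch outputs from the loop above; the standalone function name and the coefficient values are assumptions, not the project's API.

import torch.nn.functional as F


def a2c_loss(policy_logits, values, actions_t, vals_ref_v,
             value_coeff=0.5, entropy_coeff=0.01):
    # Illustrative sketch of a standard A2C objective.
    values = values.squeeze(-1)

    # critic: regress the value head towards the bootstrapped returns
    value_loss = F.mse_loss(values, vals_ref_v)

    # actor: log-probability of the chosen actions weighted by the advantage
    log_probs = F.log_softmax(policy_logits, dim=1)
    advantages = vals_ref_v - values.detach()
    chosen_log_probs = log_probs[range(len(actions_t)), actions_t]
    policy_loss = -(advantages * chosen_log_probs).mean()

    # entropy bonus keeps the policy from collapsing too early
    probs = F.softmax(policy_logits, dim=1)
    entropy = -(probs * log_probs).sum(dim=1).mean()

    return policy_loss + value_coeff * value_loss - entropy_coeff * entropy

With the batching shown in train, this would be invoked roughly as loss = a2c_loss(policy_val, v_val, actions_t, vals_ref) before loss.backward() and self.optimizer.step().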