@@ -30,6 +30,7 @@ mutable struct GymEnv{T} <: AbstractGymEnv
3030 pyreset:: PyObject # the python env.reset function
3131 pystate:: PyObject # the state array object referenced by the PyArray state.o
3232 pystepres:: PyObject # used to make stepping the env slightly more efficient
33+ pytplres:: PyObject # scratch PyObject passed to unsafe_gettpl! when reading the reward and done entries out of pystepres
3334 info:: PyObject # store it as a PyObject for speed, since often unused
3435 state:: T
3536 reward:: Float64
@@ -40,7 +41,7 @@ mutable struct GymEnv{T} <: AbstractGymEnv
4041 pystate = pycall (pyenv[" reset" ], PyObject)
4142 state = convert (stateT, pystate)
4243 env = new {typeof(state)} (name, pyenv, pyenv[" step" ], pyenv[" reset" ],
43- pystate, PyNULL (), PyNULL (), state)
44+ pystate, PyNULL (), PyNULL (), PyNULL (), state)
4445 reset! (env)
4546 env
4647 end
@@ -137,13 +138,10 @@ function Reinforce.step!(env::GymEnv{T}, a) where T <: PyArray
137138 pyact = pyaction (a)
138139 pycall! (env. pystepres, env. pystep, PyObject, pyact)
139140
140- env. pystate, r, env. done, env. info =
141- convert (Tuple{PyObject, Float64, Bool, PyObject}, env. pystepres)
142-
141+ unsafe_gettpl! (env. pystate, env. pystepres, PyObject, 0 )
143142 setdata! (env. state, env. pystate)
144143
145- env. total_reward += r
146- return (r, env. state)
144+ return gymstep! (env)
147145end
148146
149147"""
@@ -153,11 +151,16 @@ function Reinforce.step!(env::GymEnv{T}, a) where T
153151 pyact = pyaction (a)
154152 pycall! (env. pystepres, env. pystep, PyObject, pyact)
155153
156- env. pystate, r, env. done, env. info =
157- convert (Tuple{PyObject, Float64, Bool, PyObject}, env. pystepres)
158-
154+ unsafe_gettpl! (env. pystate, env. pystepres, PyObject, 0 )
159155 env. state = convert (T, env. pystate)
160156
157+ return gymstep! (env)
158+ end
159+
# gymstep!(env)
#
# Shared tail of the `step!` methods: once `env.pystepres` holds the result of
# the Python step call and the caller has already pulled out entry 0 (the
# state), this reads entries 1, 2 and 3 of `env.pystepres` as the reward
# (Float64), the done flag (Bool) and the info object (PyObject), adds the
# reward to `env.total_reward`, and returns the tuple `(r, env.state)`.
#
# NOTE(review): the indices presumably follow Python's 0-based tuple positions
# of the step result (state, reward, done, info) -- confirm against the
# definition of `unsafe_gettpl!`, which is not visible here.
160+ @inline function gymstep! (env)
# `env.pytplres` is passed as the first argument for both the reward and the
# done read; presumably it is a reusable destination object that each call
# overwrites -- TODO confirm. Keep these two reads in this order.
161+ r = unsafe_gettpl! (env. pytplres, env. pystepres, Float64, 1 )
162+ env. done = unsafe_gettpl! (env. pytplres, env. pystepres, Bool, 2 )
# The info entry is requested as a PyObject with `env.info` as the first
# argument; the call's return value is not used here.
163+ unsafe_gettpl! (env. info, env. pystepres, PyObject, 3 )
# Accumulate the running reward total before handing the step result back.
161164 env. total_reward += r
162165 return (r, env. state)
163166end
0 commit comments