Skip to content
This repository was archived by the owner on May 21, 2022. It is now read-only.

Commit d16ea68

Browse files
committed
Use unsafe_gettpl! to speed up access to results of env.step()
Requires a PyCall PR
1 parent fc5c53e commit d16ea68

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/OpenAIGym.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ mutable struct GymEnv{T} <: AbstractGymEnv
3030
pyreset::PyObject # the python env.reset function
3131
pystate::PyObject # the state array object referenced by the PyArray state.o
3232
pystepres::PyObject # used to make stepping the env slightly more efficient
33+
pytplres::PyObject # used to make stepping the env slightly more efficient
3334
info::PyObject # store it as a PyObject for speed, since often unused
3435
state::T
3536
reward::Float64
@@ -40,7 +41,7 @@ mutable struct GymEnv{T} <: AbstractGymEnv
4041
pystate = pycall(pyenv["reset"], PyObject)
4142
state = convert(stateT, pystate)
4243
env = new{typeof(state)}(name, pyenv, pyenv["step"], pyenv["reset"],
43-
pystate, PyNULL(), PyNULL(), state)
44+
pystate, PyNULL(), PyNULL(), PyNULL(), state)
4445
reset!(env)
4546
env
4647
end
@@ -137,13 +138,10 @@ function Reinforce.step!(env::GymEnv{T}, a) where T <: PyArray
137138
pyact = pyaction(a)
138139
pycall!(env.pystepres, env.pystep, PyObject, pyact)
139140

140-
env.pystate, r, env.done, env.info =
141-
convert(Tuple{PyObject, Float64, Bool, PyObject}, env.pystepres)
142-
141+
unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
143142
setdata!(env.state, env.pystate)
144143

145-
env.total_reward += r
146-
return (r, env.state)
144+
return gymstep!(env)
147145
end
148146

149147
"""
@@ -153,11 +151,16 @@ function Reinforce.step!(env::GymEnv{T}, a) where T
153151
pyact = pyaction(a)
154152
pycall!(env.pystepres, env.pystep, PyObject, pyact)
155153

156-
env.pystate, r, env.done, env.info =
157-
convert(Tuple{PyObject, Float64, Bool, PyObject}, env.pystepres)
158-
154+
unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
159155
env.state = convert(T, env.pystate)
160156

157+
return gymstep!(env)
158+
end
159+
160+
@inline function gymstep!(env)
161+
r = unsafe_gettpl!(env.pytplres, env.pystepres, Float64, 1)
162+
env.done = unsafe_gettpl!(env.pytplres, env.pystepres, Bool, 2)
163+
unsafe_gettpl!(env.info, env.pystepres, PyObject, 3)
161164
env.total_reward += r
162165
return (r, env.state)
163166
end

0 commit comments

Comments
 (0)