Description
I am using diffrax to solve an ODE. The ODE has a free parameter, so I use jax.vmap to map the integration over different parameter values. Paired with jit, this is already quite fast on a single GPU, but I have multiple GPUs that I would like to take advantage of. I attempted to use JAX's automatic parallelization by sharding the parameter array across my available GPUs and replicating the initial position on each GPU. Unfortunately, this takes two orders of magnitude longer to compile and run than the version without sharding. Is there something obvious that I am doing wrong in the code below, or a reason why this procedure performs poorly? I have tested the same sharding approach on very simple computations and it shows only marginal improvement over vmap alone, so I am wondering whether something internal to diffrax could be causing the slowdown.
from jax.sharding import PartitionSpec as P, NamedSharding
import jax.numpy as jnp
import jax

jax.config.update('jax_enable_x64', True)

from diffrax import ODETerm, ImplicitEuler, PIDController, diffeqsolve
import timeit


def ode(t, y, args):
    """dy/dt = x * (y - cos(t)), with free parameter x (e.g. x = -50 is stiff)."""
    x = args
    return x * (y - jnp.cos(t))


class integrator:
    def __init__(self, vec_field):
        self.solver = ImplicitEuler()
        self.stepsize_controller = PIDController(rtol=1e-3, atol=1e-6)
        self.term = ODETerm(vec_field)

    def __call__(self, y_0, tmax, vec_field_args, tmin=0., dt=None, max_steps=4096):
        sol = diffeqsolve(self.term, self.solver, tmin, tmax, dt, y_0,
                          args=vec_field_args,
                          stepsize_controller=self.stepsize_controller,
                          max_steps=max_steps)
        return sol.ys[0]


integ = integrator(ode)
int_vmap = jax.jit(jax.vmap(integ, in_axes=(None, None, 0)))

Regular version:
params = jnp.linspace(-50, -1, 12)
init = jnp.linspace(0.1, 1., 10)
time = timeit.repeat('int_vmap(init, 10., params).block_until_ready()',
                     globals=globals(),
                     number=2,
                     repeat=2)
print(f'took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run')

took 1.80e+00 seconds to compile, 1.25e-01 to run
Auto-parallelization version:
mesh = jax.make_mesh((4,), ('x',))
params = jax.device_put(jnp.linspace(-50, -1, 12), NamedSharding(mesh, P('x')))
init = jax.device_put(jnp.linspace(0.1, 1., 10), NamedSharding(mesh, P()))
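# Optional sanity check (my addition, not part of the timed code): confirm the
# intended layout actually landed on the devices before benchmarking.
jax.debug.visualize_array_sharding(params)  # expect 4 shards of 3 elements each
jax.debug.visualize_array_sharding(init)    # expect full replication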
time = timeit.repeat('int_vmap(init, 10., params).block_until_ready()',
                     globals=globals(),
                     number=2,
                     repeat=2)
print(f'took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run')

took 1.01e+02 seconds to compile, 8.66e+01 to run
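For reference, one variant I could also try is stating the sharding inside the jitted function with jax.lax.with_sharding_constraint, so the compiler sees the intended layout directly rather than inferring it from the inputs. This is only a minimal sketch of the idea (reusing mesh and integ from above; I have not benchmarked it):

def sharded_solve(y_0, tmax, params):
    # Pin params to the 'x' mesh axis from inside the traced computation.
    params = jax.lax.with_sharding_constraint(params, NamedSharding(mesh, P('x')))
    return jax.vmap(integ, in_axes=(None, None, 0))(y_0, tmax, params)

int_sharded = jax.jit(sharded_solve)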