
auto-parallelization alongside vmap is very slow #718

@jaxengodfrey

Description


I am using diffrax to solve an ODE. The ODE has a free parameter, so I use jax.vmap to map the integration over different parameter values. Combined with jit, this is already quite fast on a single GPU, but I have multiple GPUs that I would like to take advantage of. I attempted to use JAX's automatic parallelization by sharding the parameter array across my available GPUs and replicating the initial position on each GPU. Unfortunately, this takes two orders of magnitude longer to compile and run than the version without sharding. Is there something obvious that I am doing wrong in the code below, or a reason why this procedure performs poorly? When I tested sharding on very simple computations, it showed only a marginal improvement over vmap alone, so I'm wondering whether something internal to diffrax could be causing the slowdown.

from jax.sharding import PartitionSpec as P, NamedSharding
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)

from diffrax import ODETerm, ImplicitEuler, PIDController, diffeqsolve
import timeit

def ode(t, y, args):
    """dy/dt = x * (y - cos(t)), where x is the free parameter passed as args."""
    x = args
    return x * (y - jnp.cos(t))

class integrator():
    def __init__(self, vec_field):
        self.solver = ImplicitEuler()
        self.stepsize_controller = PIDController(rtol = 1e-3, atol = 1e-6)
        self.term = ODETerm(vec_field)

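    def __call__(self, y_0, tmax, vec_field_args, tmin = 0., dt = None, max_steps = 4096):
        # dt = None lets the adaptive PIDController choose the initial step size.
        sol = diffeqsolve(self.term, self.solver, tmin, tmax, dt, y_0,
                          args = vec_field_args,
                          stepsize_controller = self.stepsize_controller,
                          max_steps = max_steps)
        # With the default SaveAt(t1=True), ys holds a single entry: the solution at tmax.
        return sol.ys[0]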

integ = integrator(ode)
int_vmap = jax.jit(jax.vmap(integ, in_axes = (None, None, 0)))

Regular version:

params = jnp.linspace(-50, -1, 12)
init = jnp.linspace(0.1, 1., 10)

# Each repeat times two calls; only the first repeat includes tracing/compilation.
time = timeit.repeat('int_vmap(init, 10., params).block_until_ready()',
                     globals = globals(),
                     number = 2,
                     repeat = 2)

print(f'took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run')
took 1.80e+00 seconds to compile, 1.25e-01 to run
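The timeit numbers above put tracing and compilation into the first repeat, and each repeat times two calls, so the "compile" and "run" figures are not cleanly separated. JAX's ahead-of-time API (jit(...).lower(...).compile()) can split the two. A minimal sketch, assuming a hypothetical toy function decay (fixed-step explicit Euler on the same ODE) as a stand-in for the diffrax solve; the printed timings depend entirely on the machine:

```python
import time

import jax
import jax.numpy as jnp


def decay(y0, x, t1=10.0, n_steps=1000):
    """Hypothetical stand-in for the diffrax solve: fixed-step explicit Euler
    on dy/dt = x * (y - cos(t)) from t = 0 to t = t1."""
    dt = t1 / n_steps

    def step(i, y):
        return y + dt * x * (y - jnp.cos(i * dt))

    return jax.lax.fori_loop(0, n_steps, step, y0)


# Map over the parameter only (initial values are shared across the batch).
batched = jax.jit(jax.vmap(decay, in_axes=(None, 0)))

params = jnp.linspace(-50.0, -1.0, 12)
init = jnp.linspace(0.1, 1.0, 10)

# Trace, lower and compile ahead of time, then time a single call separately.
t_start = time.perf_counter()
compiled = batched.lower(init, params).compile()
t_compiled = time.perf_counter()
out = compiled(init, params).block_until_ready()
t_done = time.perf_counter()

print(f"compile: {t_compiled - t_start:.2e} s, run: {t_done - t_compiled:.2e} s")
```

For the function in this issue, the analogous call would presumably be int_vmap.lower(init, 10., params).compile(), with the compiled object then called on the same arguments.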

Auto-parallelization version:

mesh = jax.make_mesh((4,), ('x',))  # ('x') without the comma is just the string 'x', not a tuple
params = jax.device_put(jnp.linspace(-50, -1, 12), NamedSharding(mesh, P('x')))
init = jax.device_put(jnp.linspace(0.1, 1., 10), NamedSharding(mesh, P()))

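# Each repeat times two calls; only the first repeat includes tracing/compilation
# (the new input shardings trigger a fresh compile of int_vmap).
time = timeit.repeat('int_vmap(init, 10., params).block_until_ready()',
                     globals = globals(),
                     number = 2,
                     repeat = 2)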

print(f'took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run')
took 1.01e+02 seconds to compile, 8.66e+01 to run
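One way to narrow this down is to confirm how the work is actually laid out across devices. The sketch below is a hypothetical check, not diffrax itself: decay is a toy fixed-step explicit Euler loop standing in for the diffrax solve, the 1-D mesh is built with jax.sharding.Mesh over every visible device, and each device gets three parameters so the split is even. It prints the sharding and per-device shard shapes of the inputs and the output. The output layout is chosen by the compiler, so if out turns out not to be split along the parameter axis, the loop may not have been partitioned the way the sharded inputs suggest.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def decay(y0, x, t1=10.0, n_steps=1000):
    """Hypothetical stand-in for the diffrax solve (fixed-step explicit Euler)."""
    dt = t1 / n_steps

    def step(i, y):
        return y + dt * x * (y - jnp.cos(i * dt))

    return jax.lax.fori_loop(0, n_steps, step, y0)


batched = jax.jit(jax.vmap(decay, in_axes=(None, 0)))

# Use every visible device; three parameters per device keeps the split even.
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
params = jax.device_put(jnp.linspace(-50.0, -1.0, 3 * n_dev), NamedSharding(mesh, P("x")))
init = jax.device_put(jnp.linspace(0.1, 1.0, 10), NamedSharding(mesh, P()))

out = batched(init, params)

# Report how the inputs and the compiler-chosen output are laid out.
for name, arr in [("params", params), ("init", init), ("out", out)]:
    print(name, arr.sharding)
    for shard in arr.addressable_shards:
        print("   ", shard.device, shard.data.shape)
```

With four GPUs, each params shard should have shape (3,) and each init shard the full shape (10,); the shapes reported for out show whether the batched solve ended up distributed.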
