diff --git a/learned_optimization/research/distill/truncated_distill.py b/learned_optimization/research/distill/truncated_distill.py index f84a5ce..63e507d 100644 --- a/learned_optimization/research/distill/truncated_distill.py +++ b/learned_optimization/research/distill/truncated_distill.py @@ -176,7 +176,7 @@ def _multi_perturb(theta: T, key: chex.PRNGKey, std: float, def _fn(key): pos = common.sample_perturbations(theta, key, std=std) - p_theta = jax.tree_multimap(lambda t, a: t + a, theta, pos) + p_theta = jax.tree_map(lambda t, a: t + a, theta, pos) return p_theta keys = jax.random.split(key, num_samples)