From 48f183364ecd193d02daf041d5599512488c14c9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Jul 2022 13:36:05 -0700 Subject: [PATCH] Use jax.tree_util.tree_map in place of deprecated tree_multimap. PiperOrigin-RevId: 462218807 --- learned_optimization/research/distill/truncated_distill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learned_optimization/research/distill/truncated_distill.py b/learned_optimization/research/distill/truncated_distill.py index f84a5ce..63e507d 100644 --- a/learned_optimization/research/distill/truncated_distill.py +++ b/learned_optimization/research/distill/truncated_distill.py @@ -176,7 +176,7 @@ def _multi_perturb(theta: T, key: chex.PRNGKey, std: float, def _fn(key): pos = common.sample_perturbations(theta, key, std=std) - p_theta = jax.tree_multimap(lambda t, a: t + a, theta, pos) + p_theta = jax.tree_map(lambda t, a: t + a, theta, pos) return p_theta keys = jax.random.split(key, num_samples)