From 61f6d77697a39cb66eabe4172db9b6692baff10d Mon Sep 17 00:00:00 2001
From: init2winit Team
Date: Fri, 30 May 2025 21:15:10 -0700
Subject: [PATCH] extending bubbles to general tensors

PiperOrigin-RevId: 765452194
---
 init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py
index 0579dd47..9618adc6 100644
--- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py
+++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py
@@ -142,8 +142,13 @@ def pth_inv_root_rmn(
   x = x.astype(jnp.float32)
   n = x.shape[-1]
 
-  alpha = jax.lax.sqrt(jnp.linalg.norm(x, ord=1))
-  alpha *= jax.lax.sqrt(jnp.linalg.norm(x, ord=jnp.inf))
+  # Based on Gelfand's inequality, the lines below provide a tighter
+  # upper bound on the largest singular value of the matrix:
+  # \sigma_max \leq \| (X X^\top)^k \|_F^{1/(2k)}, here with k = 2.
+  xx = x @ x.T
+  alpha = jnp.power(jnp.linalg.norm(xx @ xx), 0.25)
+  # alpha = jax.lax.sqrt(jnp.linalg.norm(x, ord=1))
+  # alpha *= jax.lax.sqrt(jnp.linalg.norm(x, ord=jnp.inf))
   alpha = lax.select(alpha == 0, jnp.ones_like(alpha), alpha)
   beta = _scalar_inverse_root(alpha, p)
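
For reference, below is a minimal, self-contained JAX sketch of the two upper
bounds on the largest singular value that this change swaps: the classical
sqrt(||X||_1 * ||X||_inf) bound it removes, and the Gelfand-style bound
||(X X^T)^2||_F^(1/4) it adopts. The helper names old_bound and gelfand_bound
and the random test matrix are illustrative only and do not appear in
pth_inv_root_rmn.py.

import jax
import jax.numpy as jnp

def old_bound(x):
  # Classical bound: sigma_max(X) <= sqrt(||X||_1 * ||X||_inf), where
  # ord=1 is the maximum column sum and ord=inf the maximum row sum.
  return jax.lax.sqrt(
      jnp.linalg.norm(x, ord=1) * jnp.linalg.norm(x, ord=jnp.inf))

def gelfand_bound(x):
  # Gelfand-style bound with k = 2:
  # sigma_max(X) <= ||(X X^T)^k||_F^(1/(2k)) = ||(X X^T)^2||_F^(1/4).
  xx = x @ x.T
  return jnp.power(jnp.linalg.norm(xx @ xx), 0.25)

x = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
exact = jnp.linalg.norm(x, ord=2)  # sigma_max, i.e. the spectral norm
print(exact, gelfand_bound(x), old_bound(x))

Both quantities dominate sigma_max, since the Frobenius norm bounds the
spectral radius and (X X^T)^2 has largest eigenvalue sigma_max^4; the patch's
premise is that the Gelfand-style bound is typically the tighter of the two,
giving a smaller alpha and hence a better-scaled beta from
_scalar_inverse_root(alpha, p).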