From 91b219a491c7571cded2daac63a360927d4a62bf Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Thu, 23 May 2024 13:51:04 -0700 Subject: [PATCH] Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. PiperOrigin-RevId: 636665054 --- init2winit/test_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/init2winit/test_utils.py b/init2winit/test_utils.py index 6a90a6f2..6a87495c 100644 --- a/init2winit/test_utils.py +++ b/init2winit/test_utils.py @@ -97,15 +97,18 @@ def testAppendPytree(self): latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='') saved_pytrees = latest['pytree'] if latest else [] self.assertEqual( - pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))]) + pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))] + ) def testArrayAppend(self): """Test appending to an array.""" np.testing.assert_allclose( - utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4])) + utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]) + ) np.testing.assert_allclose( utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])), - jnp.array([[1, 2], [3, 4], [5, 6]])) + jnp.array([[1, 2], [3, 4], [5, 6]]), + ) def testTreeNormSqL2(self): """Test computing the squared L2 norm of a pytree.""" @@ -115,9 +118,9 @@ def testTreeNormSqL2(self): def testTreeSum(self): """Test computing the sum of a pytree.""" - pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)} + pytree = {'foo': 2 * jnp.ones(10), 'baz': jnp.ones(20)} self.assertEqual(utils.total_tree_sum(pytree), 40) + if __name__ == '__main__': absltest.main() -