diff --git a/init2winit/test_utils.py b/init2winit/test_utils.py index 6a90a6f2..6a87495c 100644 --- a/init2winit/test_utils.py +++ b/init2winit/test_utils.py @@ -97,15 +97,18 @@ def testAppendPytree(self): latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='') saved_pytrees = latest['pytree'] if latest else [] self.assertEqual( - pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))]) + pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))] + ) def testArrayAppend(self): """Test appending to an array.""" np.testing.assert_allclose( - utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4])) + utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]) + ) np.testing.assert_allclose( utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])), - jnp.array([[1, 2], [3, 4], [5, 6]])) + jnp.array([[1, 2], [3, 4], [5, 6]]), + ) def testTreeNormSqL2(self): """Test computing the squared L2 norm of a pytree.""" @@ -115,9 +118,9 @@ def testTreeNormSqL2(self): def testTreeSum(self): """Test computing the sum of a pytree.""" - pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)} + pytree = {'foo': 2 * jnp.ones(10), 'baz': jnp.ones(20)} self.assertEqual(utils.total_tree_sum(pytree), 40) + if __name__ == '__main__': absltest.main() -