From 6efb03cf0d6f2f830307d8022ead7d9301e7819b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 Jun 2022 15:14:44 -0700 Subject: [PATCH] [x64] make tree_util_test compatible with strict dtype promotion --- tests/tree_util_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 1141f084e..d8d4748f3 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -19,6 +19,7 @@ import re from absl.testing import absltest from absl.testing import parameterized +import jax from jax import tree_util from jax import flatten_util from jax._src import test_util as jtu @@ -384,6 +385,7 @@ class RavelUtilTest(jtu.JaxTestCase): tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) + @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedFloatInt(self): tree = [jnp.array([3], jnp.int32), jnp.array([[1., 2.], [3., 4.]], jnp.float32)] @@ -392,6 +394,7 @@ class RavelUtilTest(jtu.JaxTestCase): tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) + @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedIntBool(self): tree = [jnp.array([0], jnp.bool_), jnp.array([[1, 2], [3, 4]], jnp.int32)] @@ -400,6 +403,7 @@ class RavelUtilTest(jtu.JaxTestCase): tree_ = unravel(raveled) self.assertAllClose(tree, tree_, atol=0., rtol=0.) + @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testMixedFloatComplex(self): tree = [jnp.array([1.], jnp.float32), jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)] @@ -423,6 +427,7 @@ class RavelUtilTest(jtu.JaxTestCase): x_ = unravel(y) self.assertEqual(x_.dtype, y.dtype) + @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testDtypeMonomorphicUnravel(self): # https://github.com/google/jax/issues/7809 x1 = jnp.arange(10, dtype=jnp.float32)