[x64] make tree_util_test compatible with strict dtype promotion

This commit is contained in:
Jake VanderPlas 2022-06-14 15:14:44 -07:00
parent cd565f8f41
commit 6efb03cf0d

View File

@ -19,6 +19,7 @@ import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import tree_util
from jax import flatten_util
from jax._src import test_util as jtu
@ -384,6 +385,7 @@ class RavelUtilTest(jtu.JaxTestCase):
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedFloatInt(self):
tree = [jnp.array([3], jnp.int32),
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
@ -392,6 +394,7 @@ class RavelUtilTest(jtu.JaxTestCase):
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedIntBool(self):
tree = [jnp.array([0], jnp.bool_),
jnp.array([[1, 2], [3, 4]], jnp.int32)]
@ -400,6 +403,7 @@ class RavelUtilTest(jtu.JaxTestCase):
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testMixedFloatComplex(self):
tree = [jnp.array([1.], jnp.float32),
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
@ -423,6 +427,7 @@ class RavelUtilTest(jtu.JaxTestCase):
x_ = unravel(y)
self.assertEqual(x_.dtype, y.dtype)
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testDtypeMonomorphicUnravel(self):
# https://github.com/google/jax/issues/7809
x1 = jnp.arange(10, dtype=jnp.float32)