From d63a9442bbe75753cd8a12fbc82dcf4ee5a1b043 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 26 Sep 2022 14:38:06 -0700 Subject: [PATCH] Change jax_jit_test to be a jax_test() under Bazel that works across backends. Make it pass under TPU if x64 types are enabled. PiperOrigin-RevId: 476994286 --- tests/BUILD | 31 +++---------------------------- tests/jax_jit_test.py | 42 ++++++++++++++++++------------------------ 2 files changed, 21 insertions(+), 52 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 145a6cc63..5ca0249fb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -257,35 +257,10 @@ jax_test( ], ) -py_test( - name = "jax_jit_test_x32", - srcs = ["jax_jit_test.py"], - main = "jax_jit_test.py", - visibility = ["//visibility:private"], - deps = [ - "//jax", - "//jax:test_util", - ], -) - -py_test( - name = "jax_jit_test_x64", - srcs = ["jax_jit_test.py"], - args = ["--jax_enable_x64=true"], - main = "jax_jit_test.py", - visibility = ["//visibility:private"], - deps = [ - "//jax", - "//jax:test_util", - ], -) - -test_suite( +jax_test( name = "jax_jit_test", - tests = [ - "jax_jit_test_x32", - "jax_jit_test_x64", - ], + srcs = ["jax_jit_test.py"], + main = "jax_jit_test.py", ) py_test( diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 7fa3cc62c..7f6bc35a7 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import api -from jax._src import abstract_arrays from jax import dtypes from jax._src import lib as jaxlib from jax import numpy as jnp @@ -27,15 +26,7 @@ from jax._src import test_util as jtu from jax.config import config import numpy as np - -# It covers all JAX numpy types types except bfloat16 and numpy array. -# TODO(jblespiau): Add support for float0 in the C++ path. -_EXCLUDED_TYPES = [np.ndarray] - -_SCALAR_NUMPY_TYPES = [ - x for x in abstract_arrays.array_types if x not in _EXCLUDED_TYPES -] - +config.parse_flags_with_absl() def _cpp_device_put(value, device): return jaxlib.jax_jit.device_put(value, config.x64_enabled, device) @@ -52,7 +43,7 @@ class JaxJitTest(jtu.JaxTestCase): def test_device_put_on_numpy_scalars(self, device_put_function): device = jax.devices()[0] - for dtype in _SCALAR_NUMPY_TYPES: + for dtype in jtu.supported_dtypes(): value = dtype(0) output_buffer = device_put_function(value, device=device) @@ -66,7 +57,7 @@ class JaxJitTest(jtu.JaxTestCase): def test_device_put_on_numpy_arrays(self, device_put_function): device = jax.devices()[0] - for dtype in _SCALAR_NUMPY_TYPES: + for dtype in jtu.supported_dtypes(): value = np.zeros((3, 4), dtype=dtype) output_buffer = device_put_function(value, device=device) @@ -135,10 +126,12 @@ class JaxJitTest(jtu.JaxTestCase): self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype) # Complex - res = np.asarray(_cpp_device_put(1 + 1j, device)) - self.assertEqual(res, 1 + 1j) - self.assertEqual(res.dtype, complex_type) - self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype) + if not (config.x64_enabled and jtu.device_under_test() == "tpu"): + # No TPU support for complex128. + res = np.asarray(_cpp_device_put(1 + 1j, device)) + self.assertEqual(res, 1 + 1j) + self.assertEqual(res.dtype, complex_type) + self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype) def test_convert_int_overflow(self): with self.assertRaisesRegex( @@ -151,7 +144,7 @@ class JaxJitTest(jtu.JaxTestCase): jax_enable_x64 = config.x64_enabled # 1. Numpy scalar types - for dtype in _SCALAR_NUMPY_TYPES: + for dtype in jtu.supported_dtypes(): value = dtype(0) signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64) @@ -160,7 +153,7 @@ class JaxJitTest(jtu.JaxTestCase): self.assertFalse(signature.weak_type) # 2. Numpy arrays - for dtype in _SCALAR_NUMPY_TYPES: + for dtype in jtu.supported_dtypes(): value = np.zeros((3, 4), dtype=dtype) signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64) @@ -194,11 +187,13 @@ class JaxJitTest(jtu.JaxTestCase): self.assertEqual(signature.shape, ()) self.assertTrue(signature.weak_type) # Complex - signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64) - self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype) - self.assertEqual(signature.dtype, complex_type) - self.assertEqual(signature.shape, ()) - self.assertTrue(signature.weak_type) + if not (jax_enable_x64 and jtu.device_under_test() == "tpu"): + # No TPU support for complex128. + signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64) + self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype) + self.assertEqual(signature.dtype, complex_type) + self.assertEqual(signature.shape, ()) + self.assertTrue(signature.weak_type) def test_signature_support(self): jit = partial(api._jit, True) @@ -210,5 +205,4 @@ class JaxJitTest(jtu.JaxTestCase): if __name__ == "__main__": - jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader())