Change jax_jit_test to be a jax_test() under Bazel that works across backends.

Make it pass under TPU if x64 types are enabled.

PiperOrigin-RevId: 476994286
This commit is contained in:
Peter Hawkins 2022-09-26 14:38:06 -07:00 committed by jax authors
parent 265b39d23f
commit d63a9442bb
2 changed files with 21 additions and 52 deletions

View File

@ -257,35 +257,10 @@ jax_test(
],
)
py_test(
name = "jax_jit_test_x32",
srcs = ["jax_jit_test.py"],
main = "jax_jit_test.py",
visibility = ["//visibility:private"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "jax_jit_test_x64",
srcs = ["jax_jit_test.py"],
args = ["--jax_enable_x64=true"],
main = "jax_jit_test.py",
visibility = ["//visibility:private"],
deps = [
"//jax",
"//jax:test_util",
],
)
test_suite(
jax_test(
name = "jax_jit_test",
tests = [
"jax_jit_test_x32",
"jax_jit_test_x64",
],
srcs = ["jax_jit_test.py"],
main = "jax_jit_test.py",
)
py_test(

View File

@ -19,7 +19,6 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api
from jax._src import abstract_arrays
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
@ -27,15 +26,7 @@ from jax._src import test_util as jtu
from jax.config import config
import numpy as np
# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 in the C++ path.
_EXCLUDED_TYPES = [np.ndarray]
_SCALAR_NUMPY_TYPES = [
x for x in abstract_arrays.array_types if x not in _EXCLUDED_TYPES
]
config.parse_flags_with_absl()
def _cpp_device_put(value, device):
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
@ -52,7 +43,7 @@ class JaxJitTest(jtu.JaxTestCase):
def test_device_put_on_numpy_scalars(self, device_put_function):
device = jax.devices()[0]
for dtype in _SCALAR_NUMPY_TYPES:
for dtype in jtu.supported_dtypes():
value = dtype(0)
output_buffer = device_put_function(value, device=device)
@ -66,7 +57,7 @@ class JaxJitTest(jtu.JaxTestCase):
def test_device_put_on_numpy_arrays(self, device_put_function):
device = jax.devices()[0]
for dtype in _SCALAR_NUMPY_TYPES:
for dtype in jtu.supported_dtypes():
value = np.zeros((3, 4), dtype=dtype)
output_buffer = device_put_function(value, device=device)
@ -135,10 +126,12 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
# Complex
res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j)
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
if not (config.x64_enabled and jtu.device_under_test() == "tpu"):
# No TPU support for complex128.
res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j)
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
def test_convert_int_overflow(self):
with self.assertRaisesRegex(
@ -151,7 +144,7 @@ class JaxJitTest(jtu.JaxTestCase):
jax_enable_x64 = config.x64_enabled
# 1. Numpy scalar types
for dtype in _SCALAR_NUMPY_TYPES:
for dtype in jtu.supported_dtypes():
value = dtype(0)
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
@ -160,7 +153,7 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertFalse(signature.weak_type)
# 2. Numpy arrays
for dtype in _SCALAR_NUMPY_TYPES:
for dtype in jtu.supported_dtypes():
value = np.zeros((3, 4), dtype=dtype)
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
@ -194,11 +187,13 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
# Complex
signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
self.assertEqual(signature.dtype, complex_type)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
if not (jax_enable_x64 and jtu.device_under_test() == "tpu"):
# No TPU support for complex128.
signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
self.assertEqual(signature.dtype, complex_type)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
def test_signature_support(self):
jit = partial(api._jit, True)
@ -210,5 +205,4 @@ class JaxJitTest(jtu.JaxTestCase):
if __name__ == "__main__":
jax.config.config_with_absl()
absltest.main(testLoader=jtu.JaxTestLoader())