mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Change jax_jit_test to be a jax_test() under Bazel that works across backends.
Make it pass under TPU if x64 types are enabled. PiperOrigin-RevId: 476994286
This commit is contained in:
parent
265b39d23f
commit
d63a9442bb
31
tests/BUILD
31
tests/BUILD
@ -257,35 +257,10 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "jax_jit_test_x32",
|
||||
srcs = ["jax_jit_test.py"],
|
||||
main = "jax_jit_test.py",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "jax_jit_test_x64",
|
||||
srcs = ["jax_jit_test.py"],
|
||||
args = ["--jax_enable_x64=true"],
|
||||
main = "jax_jit_test.py",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
test_suite(
|
||||
jax_test(
|
||||
name = "jax_jit_test",
|
||||
tests = [
|
||||
"jax_jit_test_x32",
|
||||
"jax_jit_test_x64",
|
||||
],
|
||||
srcs = ["jax_jit_test.py"],
|
||||
main = "jax_jit_test.py",
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -19,7 +19,6 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import abstract_arrays
|
||||
from jax import dtypes
|
||||
from jax._src import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
@ -27,15 +26,7 @@ from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
import numpy as np
|
||||
|
||||
|
||||
# It covers all JAX numpy types types except bfloat16 and numpy array.
|
||||
# TODO(jblespiau): Add support for float0 in the C++ path.
|
||||
_EXCLUDED_TYPES = [np.ndarray]
|
||||
|
||||
_SCALAR_NUMPY_TYPES = [
|
||||
x for x in abstract_arrays.array_types if x not in _EXCLUDED_TYPES
|
||||
]
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def _cpp_device_put(value, device):
|
||||
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
|
||||
@ -52,7 +43,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
def test_device_put_on_numpy_scalars(self, device_put_function):
|
||||
|
||||
device = jax.devices()[0]
|
||||
for dtype in _SCALAR_NUMPY_TYPES:
|
||||
for dtype in jtu.supported_dtypes():
|
||||
value = dtype(0)
|
||||
|
||||
output_buffer = device_put_function(value, device=device)
|
||||
@ -66,7 +57,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
def test_device_put_on_numpy_arrays(self, device_put_function):
|
||||
|
||||
device = jax.devices()[0]
|
||||
for dtype in _SCALAR_NUMPY_TYPES:
|
||||
for dtype in jtu.supported_dtypes():
|
||||
value = np.zeros((3, 4), dtype=dtype)
|
||||
output_buffer = device_put_function(value, device=device)
|
||||
|
||||
@ -135,10 +126,12 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
|
||||
|
||||
# Complex
|
||||
res = np.asarray(_cpp_device_put(1 + 1j, device))
|
||||
self.assertEqual(res, 1 + 1j)
|
||||
self.assertEqual(res.dtype, complex_type)
|
||||
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
|
||||
if not (config.x64_enabled and jtu.device_under_test() == "tpu"):
|
||||
# No TPU support for complex128.
|
||||
res = np.asarray(_cpp_device_put(1 + 1j, device))
|
||||
self.assertEqual(res, 1 + 1j)
|
||||
self.assertEqual(res.dtype, complex_type)
|
||||
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
|
||||
|
||||
def test_convert_int_overflow(self):
|
||||
with self.assertRaisesRegex(
|
||||
@ -151,7 +144,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
jax_enable_x64 = config.x64_enabled
|
||||
|
||||
# 1. Numpy scalar types
|
||||
for dtype in _SCALAR_NUMPY_TYPES:
|
||||
for dtype in jtu.supported_dtypes():
|
||||
value = dtype(0)
|
||||
|
||||
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
|
||||
@ -160,7 +153,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
self.assertFalse(signature.weak_type)
|
||||
|
||||
# 2. Numpy arrays
|
||||
for dtype in _SCALAR_NUMPY_TYPES:
|
||||
for dtype in jtu.supported_dtypes():
|
||||
value = np.zeros((3, 4), dtype=dtype)
|
||||
|
||||
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
|
||||
@ -194,11 +187,13 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(signature.shape, ())
|
||||
self.assertTrue(signature.weak_type)
|
||||
# Complex
|
||||
signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
|
||||
self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
|
||||
self.assertEqual(signature.dtype, complex_type)
|
||||
self.assertEqual(signature.shape, ())
|
||||
self.assertTrue(signature.weak_type)
|
||||
if not (jax_enable_x64 and jtu.device_under_test() == "tpu"):
|
||||
# No TPU support for complex128.
|
||||
signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
|
||||
self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
|
||||
self.assertEqual(signature.dtype, complex_type)
|
||||
self.assertEqual(signature.shape, ())
|
||||
self.assertTrue(signature.weak_type)
|
||||
|
||||
def test_signature_support(self):
|
||||
jit = partial(api._jit, True)
|
||||
@ -210,5 +205,4 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
jax.config.config_with_absl()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user