rocm_jax/tests/jax_jit_test.py
2021-03-16 15:11:36 -04:00

209 lines
7.6 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import dtypes
from jax import lib as jaxlib
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
import numpy as np
# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 in the C++ path.
_EXCLUDED_TYPES = [np.ndarray]
_SCALAR_NUMPY_TYPES = [
x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES
]
def _cpp_device_put(value, device):
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
class JaxJitTest(parameterized.TestCase):
def test_is_float_0(self):
self.assertTrue(
jaxlib.jax_jit._is_float0(np.zeros((5, 5), dtype=jax.float0)))
self.assertFalse(jaxlib.jax_jit._is_float0(np.zeros((5, 5))))
@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_numpy_scalars(self, device_put_function):
device = jax.devices()[0]
for dtype in _SCALAR_NUMPY_TYPES:
value = dtype(0)
output_buffer = device_put_function(value, device=device)
self.assertFalse(output_buffer.aval.weak_type)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_numpy_arrays(self, device_put_function):
device = jax.devices()[0]
for dtype in _SCALAR_NUMPY_TYPES:
value = np.zeros((3, 4), dtype=dtype)
output_buffer = device_put_function(value, device=device)
self.assertFalse(output_buffer.aval.weak_type)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
dtype=dtype))
@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_buffers(self, device_put_function):
device = jax.devices()[0]
jitted_f = jax.jit(lambda x: x + 1)
# We run it twice, to cover `_DeviceArray` and the C++ `Buffer`.
for value in range(2):
buffer = jitted_f(value)
output_buffer = device_put_function(buffer, device=device)
self.assertEqual(output_buffer.dtype, buffer.dtype)
self.assertEqual(output_buffer.aval, buffer.aval)
np.testing.assert_array_equal(output_buffer, np.array(value + 1))
@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_sharded_device_array(self, device_put_function):
device = jax.devices()[0]
pmaped_f = jax.pmap(lambda x: x + 1)
for _ in range(2):
sda = pmaped_f(np.asarray([[1]]))
output_buffer = device_put_function(sda, device=device)
self.assertNotIsInstance(output_buffer,
jax.interpreters.pxla.ShardedDeviceArray)
self.assertEqual(output_buffer.dtype, sda.dtype)
self.assertEqual(output_buffer.aval, sda.aval)
np.testing.assert_array_equal(output_buffer, np.asarray(sda))
def test_device_put_on_python_scalars(self):
device = jax.devices()[0]
int_type = dtypes.canonicalize_dtype(np.int64)
float_type = dtypes.canonicalize_dtype(np.float64)
complex_type = dtypes.canonicalize_dtype(np.complex128)
# int
res = _cpp_device_put(1, device).to_py()
self.assertEqual(res, 1)
self.assertEqual(res.dtype, int_type)
# We also compare to the Python Jax API, to make sure we have the exact
# same behavior. When Jax removes the flag and removes this feature, this
# test will fail.
self.assertEqual(jnp.asarray(1).dtype, res.dtype)
# float
res = _cpp_device_put(1.0, device).to_py()
self.assertEqual(res, 1.0)
self.assertEqual(res.dtype, float_type)
self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)
# bool
for bool_value in [True, False]:
res = _cpp_device_put(bool_value, device).to_py()
self.assertEqual(res, np.asarray(bool_value))
self.assertEqual(res.dtype, np.bool_)
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
# Complex
res = _cpp_device_put(1 + 1j, device).to_py()
self.assertEqual(res, 1 + 1j)
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
def test_convert_int_overflow(self):
with self.assertRaisesRegex(
RuntimeError,
"(Python int too large|Unable to convert Python scalar).*"):
jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])
def test_arg_signature_of_value(self):
"""Tests the C++ code-path."""
jax_enable_x64 = config.x64_enabled
# 1. Numpy scalar types
for dtype in _SCALAR_NUMPY_TYPES:
value = dtype(0)
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(value).dtype)
self.assertEqual(signature.shape, ())
self.assertFalse(signature.weak_type)
# 2. Numpy arrays
for dtype in _SCALAR_NUMPY_TYPES:
value = np.zeros((3, 4), dtype=dtype)
signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(value).dtype)
self.assertEqual(signature.shape, (3, 4))
self.assertFalse(signature.weak_type)
int_type = dtypes.canonicalize_dtype(np.int64)
float_type = dtypes.canonicalize_dtype(np.float64)
complex_type = dtypes.canonicalize_dtype(np.complex128)
# 3. Python scalar types
# int
signature = jaxlib.jax_jit._ArgSignatureOfValue(1, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(1).dtype)
self.assertEqual(signature.dtype, int_type)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
# float
signature = jaxlib.jax_jit._ArgSignatureOfValue(1.0, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(1.0).dtype)
self.assertEqual(signature.dtype, float_type)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
# bool
for bool_value in [True, False]:
signature = jaxlib.jax_jit._ArgSignatureOfValue(bool_value,
jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(bool_value).dtype)
self.assertEqual(signature.dtype, np.bool_)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
# Complex
signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
self.assertEqual(signature.dtype, complex_type)
self.assertEqual(signature.shape, ())
self.assertTrue(signature.weak_type)
def test_signature_support(self):
def f(a, b, c):
return a + b + c
jitted_f = jax.api._cpp_jit(f)
self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))
if __name__ == "__main__":
jax.config.config_with_absl()
absltest.main(testLoader=jtu.JaxTestLoader())