rocm_jax/tests/jax_jit_test.py
Jean-Baptiste Lespiau bdd65453b4
Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:

- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- support for applying @jax.jit to top-level functions, by delaying retrieval of the device and backend.
- concurrency support.

I am not aware of any missing feature (but I suspect there are still some behavioral differences, due to the differences between xla_computation and _xla_callable).

See:

- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
 cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 failures and 92 passes, but many of the failures share a common cause).
2020-09-01 10:34:47 +03:00

98 lines
3.0 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lib as jaxlib
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import flags
from jax.lib import version
from jax.lib import xla_bridge
import numpy as np
# Parsed absl flag values; the tests below read `FLAGS.jax_enable_x64` to pick
# the expected integer/float/complex widths.
FLAGS = flags.FLAGS
class JaxJitTest(parameterized.TestCase):
  """Tests for the C++ jit support (`jaxlib.jax_jit`, `jax.api._cpp_jit`)."""

  def test_convert_scalars(self):
    """`_ScalarToBuffer` must match `jnp.asarray` for Python scalars.

    Each Python scalar kind (int, float, bool, complex) is converted to a
    buffer by the C++ helper and read back with `to_py()`. The test checks
    three things for each scalar:

    - the value round-trips;
    - the dtype has the width selected by the `jax_enable_x64` flag;
    - the dtype agrees with the one `jnp.asarray` produces for the same
      scalar.
    """
    # TODO(jblespiau): Remove when the version is out.
    if version < (0, 1, 53):
      # Skip rather than `return`, so the runner does not report a pass.
      self.skipTest("Requires jax.lib.version >= (0, 1, 53).")

    jax_jit = jaxlib.jax_jit
    jax_enable_x64 = FLAGS.jax_enable_x64
    if jax_enable_x64:
      int_type, float_type, complex_type = np.int64, np.float64, np.complex128
    else:
      int_type, float_type, complex_type = np.int32, np.float32, np.complex64
    backend = xla_bridge.get_backend()

    # `np.bool_` rather than `np.bool`: the latter alias was removed in
    # NumPy 1.24.
    cases = [
        (1, int_type),
        (1.0, float_type),
        (True, np.bool_),
        (False, np.bool_),
        (1 + 1j, complex_type),
    ]
    for value, expected_dtype in cases:
      with self.subTest(value=value):
        res = jax_jit._ScalarToBuffer(value, jax_enable_x64, backend).to_py()
        self.assertEqual(res, np.asarray(value))
        self.assertEqual(res.dtype, expected_dtype)
        # We also compare to the Python Jax API, to make sure we have the
        # exact same behavior. When Jax removes the flag and removes this
        # feature, this test will fail.
        self.assertEqual(jnp.asarray(value).dtype, res.dtype)

  def test_signature_support(self):
    """`_cpp_jit` must expose the same `inspect.signature` as the original."""
    if version < (0, 1, 54):
      self.skipTest("Requires jax.lib.version >= (0, 1, 54).")

    def f(a, b, c):
      return a + b + c

    jitted_f = jax.api._cpp_jit(f)
    self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))
if __name__ == '__main__':
  # Script entry point: call jax.config.config_with_absl() first, then hand
  # control to absltest using JAX's custom test loader.
  jax.config.config_with_absl()
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)