diff --git a/tests/BUILD b/tests/BUILD index 2af148a58..c6ef3a483 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -201,6 +201,15 @@ jax_test( ], ) +jax_test( + name = "aot_test", + srcs = ["aot_test.py"], + tags = ["multiaccelerator"], + deps = [ + "//jax:experimental", + ], +) + jax_test( name = "image_test", srcs = ["image_test.py"], diff --git a/tests/aot_test.py b/tests/aot_test.py new file mode 100644 index 000000000..de3276e48 --- /dev/null +++ b/tests/aot_test.py @@ -0,0 +1,102 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for GlobalDeviceArray.""" + +import contextlib +import os +import unittest +from absl.testing import absltest +import numpy as np + +import jax +from jax._src import core +from jax._src import test_util as jtu +from jax._src import xla_bridge as xb +from jax.experimental.pjit import pjit +from jax.experimental.serialize_executable import ( + serialize, deserialize_and_load) +from jax.sharding import PartitionSpec as P + +from jax.config import config +config.parse_flags_with_absl() + +prev_xla_flags = None + +with contextlib.suppress(ImportError): + import pytest + pytestmark = pytest.mark.multiaccelerator + + +# Run all tests with 8 CPU devices. +def setUpModule(): + global prev_xla_flags + prev_xla_flags = os.getenv("XLA_FLAGS") + flags_str = prev_xla_flags or "" + # Don't override user-specified device count, or other XLA flags. + if "xla_force_host_platform_device_count" not in flags_str: + os.environ["XLA_FLAGS"] = (flags_str + + " --xla_force_host_platform_device_count=8") + # Clear any cached backends so new CPU backend will pick up the env var. + xb.get_backend.cache_clear() + +# Reset to previous configuration in case other test modules will be run. +def tearDownModule(): + if prev_xla_flags is None: + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = prev_xla_flags + xb.get_backend.cache_clear() + + +class JaxAotTest(jtu.JaxTestCase): + + def check_for_compile_options(self): + example_exe = jax.jit(lambda x: x * x).lower( + core.ShapedArray( + (2, 2), dtype=np.float32)).compile()._executable.xla_executable + + # Skip if CompileOptions is not available. This is true on + # CPU/GPU/Cloud TPU for now. + try: + example_exe.compile_options() + except Exception as e: + if str(e) == 'UNIMPLEMENTED: CompileOptions not available.': + raise unittest.SkipTest('Serialization not supported') + raise e + + def test_pickle_pjit_lower(self): + self.check_for_compile_options() + + def fun(x): + return x * x + + with jax.sharding.Mesh(np.array(jax.devices()), ('data',)): + lowered = pjit( + fun, in_shardings=P('data'), out_shardings=P(None, 'data') + ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) + + def verify_serialization(lowered): + serialized, in_tree, out_tree = serialize(lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) + self.assertEqual(compiled.as_text(), lowered.compile().as_text()) + + verify_serialization(lowered) + verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100))) + verify_serialization( + jax.pmap(lambda x: x * x).lower( + np.zeros((len(jax.devices()), 4), dtype=np.float32))) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_test.py b/tests/array_test.py index 6b1bbfa52..bb4060270 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -16,7 +16,6 @@ import contextlib import math import os -import unittest from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -31,8 +30,6 @@ from jax._src.lib import xla_client as xc from jax._src.util import safe_zip from jax.interpreters import pxla from jax.experimental.pjit import pjit -from jax.experimental.serialize_executable import ( - serialize, deserialize_and_load) from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P from jax._src import array @@ -104,7 +101,7 @@ class JaxArrayTest(jtu.JaxTestCase): @parameterized.named_parameters( ("mesh_x_y", P("x", "y"), - # There are more slices but for convienient purposes, checking for only + # There are more slices but for convenient purposes, checking for only # 2. The indices + shard_shape + replica_id should be unique enough. ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))), (2, 1), @@ -1018,38 +1015,6 @@ class RngShardingTest(jtu.JaxTestCase): y_ref1 = f(jax.device_put(x, jax.devices()[0])) self.assertArraysEqual(y, y_ref1) - def test_pickle_pjit_lower(self): - example_exe = jax.jit(lambda x: x * x).lower( - core.ShapedArray( - (2, 2), dtype=np.float32)).compile()._executable.xla_executable - - # Skip if CompileOptions is not available. This is true on - # CPU/GPU/Cloud TPU for now. - try: - example_exe.compile_options() - except Exception as e: - if str(e) == 'UNIMPLEMENTED: CompileOptions not available.': - raise unittest.SkipTest('Serialization not supported') - raise e - - def fun(x): - return x * x - - with jax.sharding.Mesh(np.array(jax.devices()), ('data',)): - lowered = pjit( - fun, in_shardings=P('data'), out_shardings=P(None, 'data') - ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) - - def verify_serialization(lowered): - serialized, in_tree, out_tree = serialize(lowered.compile()) - compiled = deserialize_and_load(serialized, in_tree, out_tree) - self.assertEqual(compiled.as_text(), lowered.compile().as_text()) - - verify_serialization(lowered) - verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100))) - verify_serialization( - jax.pmap(lambda x: x * x).lower( - np.zeros((len(jax.devices()), 4), dtype=np.float32))) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())