mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
This commit is contained in:
parent
78678ee9e1
commit
c2b15a1eb8
@ -201,6 +201,15 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "aot_test",
|
||||
srcs = ["aot_test.py"],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "image_test",
|
||||
srcs = ["image_test.py"],
|
||||
|
102
tests/aot_test.py
Normal file
102
tests/aot_test.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for GlobalDeviceArray."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import pytest
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
def setUpModule():
|
||||
global prev_xla_flags
|
||||
prev_xla_flags = os.getenv("XLA_FLAGS")
|
||||
flags_str = prev_xla_flags or ""
|
||||
# Don't override user-specified device count, or other XLA flags.
|
||||
if "xla_force_host_platform_device_count" not in flags_str:
|
||||
os.environ["XLA_FLAGS"] = (flags_str +
|
||||
" --xla_force_host_platform_device_count=8")
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
# Reset to previous configuration in case other test modules will be run.
|
||||
def tearDownModule():
|
||||
if prev_xla_flags is None:
|
||||
del os.environ["XLA_FLAGS"]
|
||||
else:
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
|
||||
class JaxAotTest(jtu.JaxTestCase):
|
||||
|
||||
def check_for_compile_options(self):
|
||||
example_exe = jax.jit(lambda x: x * x).lower(
|
||||
core.ShapedArray(
|
||||
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
|
||||
|
||||
# Skip if CompileOptions is not available. This is true on
|
||||
# CPU/GPU/Cloud TPU for now.
|
||||
try:
|
||||
example_exe.compile_options()
|
||||
except Exception as e:
|
||||
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
|
||||
raise unittest.SkipTest('Serialization not supported')
|
||||
raise e
|
||||
|
||||
def test_pickle_pjit_lower(self):
|
||||
self.check_for_compile_options()
|
||||
|
||||
def fun(x):
|
||||
return x * x
|
||||
|
||||
with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
|
||||
lowered = pjit(
|
||||
fun, in_shardings=P('data'), out_shardings=P(None, 'data')
|
||||
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = serialize(lowered.compile())
|
||||
compiled = deserialize_and_load(serialized, in_tree, out_tree)
|
||||
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
|
||||
|
||||
verify_serialization(lowered)
|
||||
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
|
||||
verify_serialization(
|
||||
jax.pmap(lambda x: x * x).lower(
|
||||
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -16,7 +16,6 @@
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -31,8 +30,6 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.util import safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
@ -104,7 +101,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# There are more slices but for convenient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
|
||||
(2, 1),
|
||||
@ -1018,38 +1015,6 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
|
||||
self.assertArraysEqual(y, y_ref1)
|
||||
|
||||
def test_pickle_pjit_lower(self):
|
||||
example_exe = jax.jit(lambda x: x * x).lower(
|
||||
core.ShapedArray(
|
||||
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
|
||||
|
||||
# Skip if CompileOptions is not available. This is true on
|
||||
# CPU/GPU/Cloud TPU for now.
|
||||
try:
|
||||
example_exe.compile_options()
|
||||
except Exception as e:
|
||||
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
|
||||
raise unittest.SkipTest('Serialization not supported')
|
||||
raise e
|
||||
|
||||
def fun(x):
|
||||
return x * x
|
||||
|
||||
with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
|
||||
lowered = pjit(
|
||||
fun, in_shardings=P('data'), out_shardings=P(None, 'data')
|
||||
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = serialize(lowered.compile())
|
||||
compiled = deserialize_and_load(serialized, in_tree, out_tree)
|
||||
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
|
||||
|
||||
verify_serialization(lowered)
|
||||
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
|
||||
verify_serialization(
|
||||
jax.pmap(lambda x: x * x).lower(
|
||||
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user