Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
#sdy add shardy CPU config for all JAX tests, disabling any known-failing test cases.

The only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard_alike

Note that `layout_test` is broken on TPU; a comment was left saying to enable it once that is fixed. Also fixed `shard_map_test`, which was broken when running Shardy on one TPU, and `aot_test`, which was breaking because it called a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
parent 32bf19ac6f
commit 44158ab0e4
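For orientation, a minimal, illustrative sketch (not part of the diff) of the skip pattern this change applies wherever a case is known to fail under Shardy: check `config.use_shardy_partitioner.value` and call `self.skipTest` with the tracking bug. The class name and test body below are hypothetical; only the config flag, `jtu.JaxTestCase`, the skip call, and the test loader come from the diff.

# Illustrative only: hypothetical test class showing the Shardy skip pattern.
from absl.testing import absltest

import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu


class ShardyGatedTest(jtu.JaxTestCase):  # hypothetical class name

  def test_feature_not_yet_supported_by_shardy(self):
    if config.use_shardy_partitioner.value:
      # Skip instead of failing while the feature is unsupported under Shardy.
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
    self.assertArraysEqual(jnp.arange(4), jnp.arange(4))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())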
@@ -25,6 +25,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
@@ -375,6 +376,8 @@ class CheckpointTest(jtu.JaxTestCase):

  @parameterized.product(input_dtype=[jnp.int4, jnp.int8])
  def test_checkpointing_with_int4(self, input_dtype):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT")
    global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True)
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)
@@ -580,6 +583,8 @@ class CheckpointTest(jtu.JaxTestCase):
      self.assertArraysEqual(s.data, np_inp[s.index])

  def test_deserialization_with_int4(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT")
    if jtu.test_device_matches(['gpu']):
      self.skipTest("Fails on GPU. Enable after it's fixed")
    dtype = jnp.int4
tests/BUILD
@@ -225,9 +225,7 @@ jax_multiplatform_test(
        "tpu_v4_2x2",
        "tpu_v5p_2x2",
        "tpu_v5e_4x2",
        "cpu_shardy",
        "gpu_2gpu_shardy",
        "tpu_v3_2x2_shardy",
        "tpu_v5e_4x2_shardy",
    ],
    shard_count = {
@@ -246,10 +244,8 @@ jax_multiplatform_test(
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
    },
    enable_configs = [
        "cpu_shardy",
        "gpu_2gpu_shardy",
        "tpu_v3_2x2_shardy",
        "tpu_v4_2x2_shardy",
        "tpu_v3_2x2",
        "gpu_2gpu",
    ],
@@ -264,6 +260,7 @@ jax_multiplatform_test(
    ],
)

# TODO(b/355263220): enable on TPU once layouts is supported with Shardy.
jax_multiplatform_test(
    name = "layout_test",
    srcs = ["layout_test.py"],
@@ -279,6 +276,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "shard_alike_test",
    srcs = ["shard_alike_test.py"],
    disable_configs = [
        "cpu_shardy",  # TODO(b/355263220): enable once shard_alike is supported.
    ],
    enable_configs = [
        "tpu_v3_2x2",
        "tpu_v5e_4x2",
@@ -309,6 +309,9 @@ jax_multiplatform_test(
    name = "mock_gpu_test",
    srcs = ["mock_gpu_test.py"],
    enable_backends = ["gpu"],
    enable_configs = [
        "gpu_2gpu_shardy",
    ],
    tags = [
        "config-cuda-only",
    ],
@@ -997,6 +1000,9 @@ jax_multiplatform_test(
        "gpu": ["--jax_num_generated_cases=40"],
        "tpu": ["--jax_num_generated_cases=40"],
    },
    disable_configs = [
        "cpu_shardy",  # TODO(b/376475853): array values mismatch, need to fix and re-enable.
    ],
    shard_count = {
        "cpu": 50,
        "gpu": 50,
@@ -1234,6 +1240,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "debugging_primitives_test",
    srcs = ["debugging_primitives_test.py"],
    disable_configs = [
        "cpu_shardy",  # TODO(b/364547005): enable once pure callbacks are supported.
    ],
    enable_configs = [
        "cpu",
        "gpu_h100",
@@ -1249,6 +1258,9 @@ jax_multiplatform_test(
    backend_tags = {
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
    },
    disable_configs = [
        "cpu_shardy",  # TODO(b/364547005): enable once pure callbacks are supported.
    ],
    enable_configs = [
        "tpu_v2_1x1",
        "tpu_v3_2x2",
@@ -1263,6 +1275,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "debugger_test",
    srcs = ["debugger_test.py"],
    disable_configs = [
        "cpu_shardy",  # TODO(b/364547005): enable once pure callbacks are supported.
    ],
    enable_configs = [
        "cpu",
        "gpu_h100",
@@ -1313,10 +1328,8 @@ jax_multiplatform_test(
    name = "shard_map_test",
    srcs = ["shard_map_test.py"],
    enable_configs = [
        "cpu_shardy",
        "gpu_2gpu_shardy",
        "tpu_v3_2x2_shardy",
        "tpu_v4_2x2_shardy",
    ],
    shard_count = {
        "cpu": 50,
@@ -1405,6 +1418,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "export_test",
    srcs = ["export_test.py"],
    disable_configs = [
        "cpu_shardy",  # TODO(b/355263220): enable once export is supported.
    ],
    enable_configs = [
        "tpu_v3_2x2",
    ],
@@ -1442,6 +1458,7 @@ jax_multiplatform_test(
    disable_configs = [
        "gpu_a100",  # TODO(b/269593297): matmul precision issues
        "gpu_h100",  # Scarce resources.
        "cpu_shardy",  # TODO(b/355263220): enable once export is supported.
    ],
    shard_count = {
        "cpu": 40,
@@ -420,6 +420,8 @@ class CompilationCacheTest(CompilationCacheTestCase):
      self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))

  def test_persistent_cache_miss_logging_with_explain(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
    with (config.explain_cache_misses(True),
          config.compilation_cache_dir("jax-cache")):
@@ -464,6 +466,8 @@ class CompilationCacheTest(CompilationCacheTestCase):

  def test_persistent_cache_miss_logging_with_no_explain(self):
    # test that cache failure messages do not get logged in WARNING
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
    with (config.explain_cache_misses(False),
          config.compilation_cache_dir("jax-cache")):
      # omitting writing to cache because compilation is too fast
@@ -17,6 +17,7 @@ import math

from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding
@@ -58,10 +59,16 @@ class MockGPUTest(jtu.JaxTestCase):
    hlo = f_lowered.compiler_ir()

    mocked_count = NUM_SHARDS * jax.local_device_count()
    self.assertIn(
        f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"',
        str(hlo)
    )
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}',
          str(hlo)
      )
    else:
      self.assertIn(
          f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"',
          str(hlo)
      )

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
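Illustrative sketch (not part of the diff) of the distinction the updated assertion relies on: under Shardy the lowered StableHLO carries `#sdy.sharding<@mesh, ...>` attributes, while under GSPMD it carries `{devices=[N,1]<=[N]}` sharding strings. The mesh construction and the jitted function below are assumptions for this standalone example, not taken from mock_gpu_test.

# Illustrative only: inspect a lowering for Shardy vs GSPMD sharding annotations.
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((jax.device_count(),), ('x',))  # assumed mesh shape
f = jax.jit(lambda a: a * 2, in_shardings=NamedSharding(mesh, P('x')))
hlo = f.lower(jnp.arange(float(jax.device_count() * 2))).compiler_ir()

if config.use_shardy_partitioner.value:
  print('sdy.sharding' in str(hlo))  # Shardy: #sdy.sharding<@mesh, ...> attributes
else:
  print('devices=' in str(hlo))      # GSPMD: "{devices=[N,1]<=[N]}" strings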
@@ -29,6 +29,7 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
@@ -1241,6 +1242,8 @@ class OpsTest(PallasBaseTest):
      "plgpu.TritonCompilerParams unavailable on Windows",
  )
  def test_debug_print(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
    if jtu.test_device_matches(["tpu"]):
      self.skipTest("Not supported on TPU")
@@ -1923,6 +1926,8 @@ class OpsInterpretTest(OpsTest):
  INTERPRET = True

  def test_debug_print(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
@@ -1235,7 +1235,11 @@ class ShardMapTest(jtu.JaxTestCase):

    hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
    if config.use_shardy_partitioner.value:
      self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
      if len(jax.devices()) > 1:
        self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
      else:
        # When devices == 1, the `sdy.manual_computation` is inlined.
        self.assertEqual(0, hlo_str.count('sdy.manual_computation'))
    else:
      self.assertIn('call @shmap_body', hlo_str)
      self.assertIn('call @shmap_body_0', hlo_str)
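Illustrative sketch (not part of the diff) of the check the amended test performs: lower a shard_map'ped function to StableHLO text and count `sdy.manual_computation` ops, which can be inlined away when only one device is present. The shard_map body, the mesh, the `mlir` import path, and the expected count of 1 for a single shard_map are assumptions for this standalone example; the real test expects 2.

# Illustrative only: count Shardy manual computations in a lowered shard_map.
from functools import partial

import jax
import jax.numpy as jnp
from jax._src import config
from jax._src.interpreters import mlir  # assumed import path for module_to_string
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def body(a):
  return a * 2

x = jnp.arange(float(jax.device_count() * 2))
hlo_str = mlir.module_to_string(jax.jit(body).lower(x).compiler_ir('stablehlo'))

if config.use_shardy_partitioner.value:
  # Assumed: one shard_map lowers to one sdy.manual_computation, inlined if 1 device.
  expected = 1 if len(jax.devices()) > 1 else 0
  print(hlo_str.count('sdy.manual_computation') == expected)
else:
  print('call @shmap_body' in hlo_str)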