#sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.

The only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU; a TODO comment was left in the BUILD file to enable it once layouts are supported with Shardy.

Also fixed a `shard_map_test` case that was broken when running Shardy on a single TPU, and `aot_test`, which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
This commit is contained in:
Bart Chrzaszcz 2024-10-30 11:39:50 -07:00 committed by jax authors
parent 32bf19ac6f
commit 44158ab0e4
6 changed files with 53 additions and 11 deletions

View File

@ -25,6 +25,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
@ -375,6 +376,8 @@ class CheckpointTest(jtu.JaxTestCase):
@parameterized.product(input_dtype=[jnp.int4, jnp.int8])
def test_checkpointing_with_int4(self, input_dtype):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT")
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True)
global_input_shape = (8, 2)
num = math.prod(global_input_shape)
@ -580,6 +583,8 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data, np_inp[s.index])
def test_deserialization_with_int4(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT")
if jtu.test_device_matches(['gpu']):
self.skipTest("Fails on GPU. Enable after it's fixed")
dtype = jnp.int4

View File

@ -225,9 +225,7 @@ jax_multiplatform_test(
"tpu_v4_2x2",
"tpu_v5p_2x2",
"tpu_v5e_4x2",
"cpu_shardy",
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v5e_4x2_shardy",
],
shard_count = {
@ -246,10 +244,8 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"cpu_shardy",
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
"tpu_v3_2x2",
"gpu_2gpu",
],
@ -264,6 +260,7 @@ jax_multiplatform_test(
],
)
# TODO(b/355263220): enable on TPU once layouts is supported with Shardy.
jax_multiplatform_test(
name = "layout_test",
srcs = ["layout_test.py"],
@ -279,6 +276,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "shard_alike_test",
srcs = ["shard_alike_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/355263220): enable once shard_alike is supported.
],
enable_configs = [
"tpu_v3_2x2",
"tpu_v5e_4x2",
@ -309,6 +309,9 @@ jax_multiplatform_test(
name = "mock_gpu_test",
srcs = ["mock_gpu_test.py"],
enable_backends = ["gpu"],
enable_configs = [
"gpu_2gpu_shardy",
],
tags = [
"config-cuda-only",
],
@ -997,6 +1000,9 @@ jax_multiplatform_test(
"gpu": ["--jax_num_generated_cases=40"],
"tpu": ["--jax_num_generated_cases=40"],
},
disable_configs = [
"cpu_shardy", # TODO(b/376475853): array values mismatch, need to fix and re-enable.
],
shard_count = {
"cpu": 50,
"gpu": 50,
@ -1234,6 +1240,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "debugging_primitives_test",
srcs = ["debugging_primitives_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported.
],
enable_configs = [
"cpu",
"gpu_h100",
@ -1249,6 +1258,9 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
disable_configs = [
"cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported.
],
enable_configs = [
"tpu_v2_1x1",
"tpu_v3_2x2",
@ -1263,6 +1275,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "debugger_test",
srcs = ["debugger_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported.
],
enable_configs = [
"cpu",
"gpu_h100",
@ -1313,10 +1328,8 @@ jax_multiplatform_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],
enable_configs = [
"cpu_shardy",
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
],
shard_count = {
"cpu": 50,
@ -1405,6 +1418,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "export_test",
srcs = ["export_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/355263220): enable once export is supported.
],
enable_configs = [
"tpu_v3_2x2",
],
@ -1442,6 +1458,7 @@ jax_multiplatform_test(
disable_configs = [
"gpu_a100", # TODO(b/269593297): matmul precision issues
"gpu_h100", # Scarce resources.
"cpu_shardy", # TODO(b/355263220): enable once export is supported.
],
shard_count = {
"cpu": 40,

View File

@ -420,6 +420,8 @@ class CompilationCacheTest(CompilationCacheTestCase):
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
def test_persistent_cache_miss_logging_with_explain(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(True),
config.compilation_cache_dir("jax-cache")):
@ -464,6 +466,8 @@ class CompilationCacheTest(CompilationCacheTestCase):
def test_persistent_cache_miss_logging_with_no_explain(self):
# test that cache failure messages do not get logged in WARNING
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(False),
config.compilation_cache_dir("jax-cache")):
# omitting writing to cache because compilation is too fast

View File

@ -17,6 +17,7 @@ import math
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding
@ -58,10 +59,16 @@ class MockGPUTest(jtu.JaxTestCase):
hlo = f_lowered.compiler_ir()
mocked_count = NUM_SHARDS * jax.local_device_count()
self.assertIn(
f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"',
str(hlo)
)
if config.use_shardy_partitioner.value:
self.assertIn(
'sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}',
str(hlo)
)
else:
self.assertIn(
f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"',
str(hlo)
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -29,6 +29,7 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
@ -1241,6 +1242,8 @@ class OpsTest(PallasBaseTest):
"plgpu.TritonCompilerParams unavailable on Windows",
)
def test_debug_print(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")
@ -1923,6 +1926,8 @@ class OpsInterpretTest(OpsTest):
INTERPRET = True
def test_debug_print(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),

View File

@ -1235,7 +1235,11 @@ class ShardMapTest(jtu.JaxTestCase):
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
if config.use_shardy_partitioner.value:
self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
if len(jax.devices()) > 1:
self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
else:
# When devices == 1, the `sdy.manual_computation` is inlined.
self.assertEqual(0, hlo_str.count('sdy.manual_computation'))
else:
self.assertIn('call @shmap_body', hlo_str)
self.assertIn('call @shmap_body_0', hlo_str)