1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

fix bug in export_module when no mesh axes are empty for shardy.

If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.

This fix provides an empty mesh, when no mesh axes in dumped module are empty.

PiperOrigin-RevId: 746058506
This commit is contained in:
Kostiantyn Liepieshov 2025-04-10 09:21:15 -07:00 committed by jax authors
parent 6dd576acd5
commit c730bbda74
3 changed files with 9 additions and 6 deletions
jax/_src/export
tests

@ -1444,12 +1444,11 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
'builtin.module(sdy-lift-inlined-meshes)')
pipeline.run(submodule.operation)
# TODO(bartchr): delete this once I have JAX export support multiple meshes.
mesh = None
if shardy_enabled:
sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule))
mesh = mesh_lib.AbstractMesh(
*list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None
mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1])
if sdy_mesh_axes else mesh_lib.empty_abstract_mesh)
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):

@ -1563,9 +1563,6 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "export_test",
srcs = ["export_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/355263220): enable once export is supported.
],
enable_configs = [
"cpu_shardy",
"gpu_p100x2_shardy",

@ -203,7 +203,14 @@ class JaxExportTest(jtu.JaxTestCase):
f = jnp.sin
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x)
self.assertAllClose(f(x), exp_f.call(x))
def test_basic_single_device_sharding(self):
device = jax.local_devices()[0]
s = jax.sharding.SingleDeviceSharding(device)
x = np.arange(16, dtype=np.float32).reshape(4, -1)
f = jax.jit(lambda x: x * 2., in_shardings=s, out_shardings=s)
exp_f = get_exported(f)(x)
self.assertAllClose(f(x), exp_f.call(x))
def test_jit_static_arg(self):