mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
fix bug in export_module
when no mesh axes are empty for shardy.
If mesh axes are empty, we are setting mesh as None, resulting in an error in this test. This fix provides an empty mesh, when no mesh axes in dumped module are empty. PiperOrigin-RevId: 746058506
This commit is contained in:
parent
6dd576acd5
commit
c730bbda74
@ -1444,12 +1444,11 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
'builtin.module(sdy-lift-inlined-meshes)')
|
||||
pipeline.run(submodule.operation)
|
||||
|
||||
# TODO(bartchr): delete this once I have JAX export support multiple meshes.
|
||||
mesh = None
|
||||
if shardy_enabled:
|
||||
sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule))
|
||||
mesh = mesh_lib.AbstractMesh(
|
||||
*list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None
|
||||
mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1])
|
||||
if sdy_mesh_axes else mesh_lib.empty_abstract_mesh)
|
||||
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
|
@ -1563,9 +1563,6 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "export_test",
|
||||
srcs = ["export_test.py"],
|
||||
disable_configs = [
|
||||
"cpu_shardy", # TODO(b/355263220): enable once export is supported.
|
||||
],
|
||||
enable_configs = [
|
||||
"cpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
|
@ -203,7 +203,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f = jnp.sin
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x)
|
||||
self.assertAllClose(f(x), exp_f.call(x))
|
||||
|
||||
def test_basic_single_device_sharding(self):
|
||||
device = jax.local_devices()[0]
|
||||
s = jax.sharding.SingleDeviceSharding(device)
|
||||
x = np.arange(16, dtype=np.float32).reshape(4, -1)
|
||||
f = jax.jit(lambda x: x * 2., in_shardings=s, out_shardings=s)
|
||||
exp_f = get_exported(f)(x)
|
||||
self.assertAllClose(f(x), exp_f.call(x))
|
||||
|
||||
def test_jit_static_arg(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user