From 262aab74f0b0e00efae4e631946552fee014b35a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Feb 2025 22:18:16 -0800 Subject: [PATCH] canonicalize closed over values if **at least** 1 mesh axis is `Manual` and **all other mesh axes** are `Manual` or `Auto`. This would make the canonicalization work properly with shmap partial-auto. If a mesh axis is Explicit, we don't canonicalize closed over values yet since that may require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map. PiperOrigin-RevId: 728956512 --- jax/_src/core.py | 7 ++++++- tests/pjit_test.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index f14763914..d8f91789b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1770,7 +1770,12 @@ def canonicalize_value(val): cur_mesh = mesh_lib.get_abstract_mesh() if cur_mesh == aval.sharding.mesh: return val - if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: + # Atleast 1 mesh axis should be Manual and all other axes should be + # Manual or Auto to allow casting. + # TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need + # cast_and_slice_p for it since shape might change? 
+ if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and + aval.sharding.mesh._are_all_axes_auto): from jax._src.pjit import mesh_cast # pytype: disable=import-error return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) return val diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 704a2d011..33d38d341 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6586,6 +6586,22 @@ class ShardingInTypesTest(jtu.JaxTestCase): shmap_f() # doesn't crash jax.jit(shmap_f)() # doesn't crash + @jtu.with_user_mesh((2, 2), ('x', 'y'), + axis_types={AxisTypes.Auto: ('x', 'y')}) + def test_shmap_close_over_partial_auto(self, mesh): + const = jnp.arange(8) + def f(): + return const * 2 + + shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'), + auto=frozenset({'y'})) + f = jax.jit(shmap_f) + out = f() + self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2])) + + jaxpr = f.trace().jaxpr + self.assertIn('mesh_cast', str(jaxpr)) + @jtu.with_user_mesh((2, 1), ('x', 'y')) def test_wsc_error(self, mesh): s = NamedSharding(mesh, P('x'))