canonicalize closed over values if **atleast** 1 mesh axis is Manual and **all other mesh axes** are Manual or Auto. This would make the canonicalization work properly with shmap partial-auto.

If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.

PiperOrigin-RevId: 728956512
This commit is contained in:
Yash Katariya 2025-02-19 22:18:16 -08:00 committed by jax authors
parent b6b319cd06
commit 262aab74f0
2 changed files with 22 additions and 1 deletions

View File

@ -1770,7 +1770,12 @@ def canonicalize_value(val):
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh == aval.sharding.mesh:
return val
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto:
# Atleast 1 mesh axis should be Manual and all other axes should be
# Manual or Auto to allow casting.
# TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need
# cast_and_slice_p for it since shape might change?
if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and
aval.sharding.mesh._are_all_axes_auto):
from jax._src.pjit import mesh_cast # pytype: disable=import-error
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim)))
return val

View File

@ -6586,6 +6586,22 @@ class ShardingInTypesTest(jtu.JaxTestCase):
shmap_f() # doesn't crash
jax.jit(shmap_f)() # doesn't crash
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={AxisTypes.Auto: ('x', 'y')})
def test_shmap_close_over_partial_auto(self, mesh):
const = jnp.arange(8)
def f():
return const * 2
shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'),
auto=frozenset({'y'}))
f = jax.jit(shmap_f)
out = f()
self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2]))
jaxpr = f.trace().jaxpr
self.assertIn('mesh_cast', str(jaxpr))
@jtu.with_user_mesh((2, 1), ('x', 'y'))
def test_wsc_error(self, mesh):
s = NamedSharding(mesh, P('x'))