mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
canonicalize closed over values if **atleast** 1 mesh axis is Manual
and **all other mesh axes** are Manual
or Auto
. This would make the canonicalization work properly with shmap partial-auto.
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map. PiperOrigin-RevId: 728956512
This commit is contained in:
parent
b6b319cd06
commit
262aab74f0
@ -1770,7 +1770,12 @@ def canonicalize_value(val):
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
if cur_mesh == aval.sharding.mesh:
|
||||
return val
|
||||
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto:
|
||||
# Atleast 1 mesh axis should be Manual and all other axes should be
|
||||
# Manual or Auto to allow casting.
|
||||
# TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need
|
||||
# cast_and_slice_p for it since shape might change?
|
||||
if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and
|
||||
aval.sharding.mesh._are_all_axes_auto):
|
||||
from jax._src.pjit import mesh_cast # pytype: disable=import-error
|
||||
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim)))
|
||||
return val
|
||||
|
@ -6586,6 +6586,22 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
shmap_f() # doesn't crash
|
||||
jax.jit(shmap_f)() # doesn't crash
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'),
|
||||
axis_types={AxisTypes.Auto: ('x', 'y')})
|
||||
def test_shmap_close_over_partial_auto(self, mesh):
|
||||
const = jnp.arange(8)
|
||||
def f():
|
||||
return const * 2
|
||||
|
||||
shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'),
|
||||
auto=frozenset({'y'}))
|
||||
f = jax.jit(shmap_f)
|
||||
out = f()
|
||||
self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2]))
|
||||
|
||||
jaxpr = f.trace().jaxpr
|
||||
self.assertIn('mesh_cast', str(jaxpr))
|
||||
|
||||
@jtu.with_user_mesh((2, 1), ('x', 'y'))
|
||||
def test_wsc_error(self, mesh):
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user