Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.

PiperOrigin-RevId: 721470917
This commit is contained in:
Yash Katariya 2025-01-30 11:59:22 -08:00 committed by jax authors
parent 2e40549c38
commit f4e2c6c34c
2 changed files with 49 additions and 14 deletions

View File

@ -2159,7 +2159,10 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts)
def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
out_mem_kinds):
if device_assignment is None:
return shardings
if len(device_assignment) == 1:
return shardings
@ -2173,7 +2176,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
axis_types=abstract_mesh.axis_types)
out = []
for s, a in zip(shardings, avals):
for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
if a.sharding.mesh.empty:
out.append(s)
@ -2182,7 +2185,8 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
for sp in a.sharding.spec])
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
out.append(NamedSharding(
_abstract_to_concrete_mesh(a.sharding.mesh), spec))
_abstract_to_concrete_mesh(a.sharding.mesh), spec,
memory_kind=mem_kind))
else:
out.append(s)
return tuple(out)
@ -2260,10 +2264,6 @@ def lower_sharding_computation(
devices_from_context)
unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
if config.sharding_in_types.value:
out_shardings = _concretize_abstract_out_shardings(
out_shardings, global_out_avals, device_assignment)
# TODO(parkers): One _raw_platform has been unified with platform,
# change this back to just read platform.
platforms = lowering_platforms or (
@ -2313,6 +2313,11 @@ def lower_sharding_computation(
propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(
closed_jaxpr, in_shardings)
if config.sharding_in_types.value:
out_shardings = _concretize_abstract_out_shardings(
out_shardings, global_out_avals, device_assignment,
propagated_out_mem_kinds)
# 2. Build up the HLO
abstract_mesh = None
@ -2563,10 +2568,6 @@ def _get_out_sharding_from_orig_sharding(
for o, out_aval in safe_zip(out_shardings, out_avals):
if (isinstance(o, sharding_impls.GSPMDSharding) and
out_aval is not core.abstract_token):
# Only return the same input sharding object if the OpShardings and
# in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
# replicated then, it doesn't encode the ndim in it. The devices
# will be the same at this point because those checks happen before.
if (orig_aval is not None and out_aval is not None and
out_aval.ndim == orig_aval.ndim
and sharding_impls.are_op_shardings_equal(
@ -2582,9 +2583,41 @@ def _get_out_sharding_from_orig_sharding(
out.append(o)
return out
def try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
recover_in_s, recover_in_aval = None, None
for in_s, in_aval in safe_zip(in_shardings, in_avals):
if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
recover_in_s, recover_in_aval = in_s, in_aval
break
if recover_in_s is None:
return new_out_shardings
res = []
for orig_out_s, out_s, out_aval in safe_zip(
orig_out_shardings, new_out_shardings, out_avals):
if (out_aval is not core.abstract_token and
mlir.all_unconstrained(orig_out_s, out_aval) and
isinstance(orig_out_s, NamedSharding) and
isinstance(out_s, NamedSharding) and
orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
out_aval.ndim == recover_in_aval.ndim and
out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
res.append(out_s.with_spec(recover_in_s.spec))
else:
res.append(out_s)
return res
def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals,
intermediate_shardings=None, context_mesh: Mesh | None = None):
intermediate_shardings=None, context_mesh: Mesh | None = None,
orig_out_shardings=None):
if orig_out_shardings is not None:
new_shardings = try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings
@ -2936,6 +2969,8 @@ class UnloadedMeshExecutable:
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_kvs, pgle_profiler)
orig_out_shardings = out_shardings
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@ -2972,7 +3007,7 @@ class UnloadedMeshExecutable:
out_shardings = maybe_recover_user_shardings(
in_shardings, out_shardings, global_in_avals, global_out_avals,
intermediate_shardings, context_mesh)
intermediate_shardings, context_mesh, orig_out_shardings)
in_shardings = finalize_shardings(in_shardings, da)
out_shardings = finalize_shardings(out_shardings, da)

View File

@ -5772,7 +5772,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
return a
out = f(arr, arr.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
def test_auto_user(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),