mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.
PiperOrigin-RevId: 721470917
This commit is contained in:
parent
2e40549c38
commit
f4e2c6c34c
@ -2159,7 +2159,10 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
|
||||
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
|
||||
donated_invars, out_shardings, out_layouts)
|
||||
|
||||
def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
|
||||
def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
|
||||
out_mem_kinds):
|
||||
if device_assignment is None:
|
||||
return shardings
|
||||
if len(device_assignment) == 1:
|
||||
return shardings
|
||||
|
||||
@ -2173,7 +2176,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
|
||||
axis_types=abstract_mesh.axis_types)
|
||||
|
||||
out = []
|
||||
for s, a in zip(shardings, avals):
|
||||
for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
|
||||
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
|
||||
if a.sharding.mesh.empty:
|
||||
out.append(s)
|
||||
@ -2182,7 +2185,8 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
|
||||
for sp in a.sharding.spec])
|
||||
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
|
||||
out.append(NamedSharding(
|
||||
_abstract_to_concrete_mesh(a.sharding.mesh), spec))
|
||||
_abstract_to_concrete_mesh(a.sharding.mesh), spec,
|
||||
memory_kind=mem_kind))
|
||||
else:
|
||||
out.append(s)
|
||||
return tuple(out)
|
||||
@ -2260,10 +2264,6 @@ def lower_sharding_computation(
|
||||
devices_from_context)
|
||||
unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
out_shardings = _concretize_abstract_out_shardings(
|
||||
out_shardings, global_out_avals, device_assignment)
|
||||
|
||||
# TODO(parkers): One _raw_platform has been unified with platform,
|
||||
# change this back to just read platform.
|
||||
platforms = lowering_platforms or (
|
||||
@ -2313,6 +2313,11 @@ def lower_sharding_computation(
|
||||
propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(
|
||||
closed_jaxpr, in_shardings)
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
out_shardings = _concretize_abstract_out_shardings(
|
||||
out_shardings, global_out_avals, device_assignment,
|
||||
propagated_out_mem_kinds)
|
||||
|
||||
# 2. Build up the HLO
|
||||
|
||||
abstract_mesh = None
|
||||
@ -2563,10 +2568,6 @@ def _get_out_sharding_from_orig_sharding(
|
||||
for o, out_aval in safe_zip(out_shardings, out_avals):
|
||||
if (isinstance(o, sharding_impls.GSPMDSharding) and
|
||||
out_aval is not core.abstract_token):
|
||||
# Only return the same input sharding object if the OpShardings and
|
||||
# in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
|
||||
# replicated then, it doesn't encode the ndim in it. The devices
|
||||
# will be the same at this point because those checks happen before.
|
||||
if (orig_aval is not None and out_aval is not None and
|
||||
out_aval.ndim == orig_aval.ndim
|
||||
and sharding_impls.are_op_shardings_equal(
|
||||
@ -2582,9 +2583,41 @@ def _get_out_sharding_from_orig_sharding(
|
||||
out.append(o)
|
||||
return out
|
||||
|
||||
|
||||
def try_matching_out_with_in_spec_for_all_auto(
|
||||
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
|
||||
recover_in_s, recover_in_aval = None, None
|
||||
for in_s, in_aval in safe_zip(in_shardings, in_avals):
|
||||
if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
|
||||
recover_in_s, recover_in_aval = in_s, in_aval
|
||||
break
|
||||
if recover_in_s is None:
|
||||
return new_out_shardings
|
||||
|
||||
res = []
|
||||
for orig_out_s, out_s, out_aval in safe_zip(
|
||||
orig_out_shardings, new_out_shardings, out_avals):
|
||||
if (out_aval is not core.abstract_token and
|
||||
mlir.all_unconstrained(orig_out_s, out_aval) and
|
||||
isinstance(orig_out_s, NamedSharding) and
|
||||
isinstance(out_s, NamedSharding) and
|
||||
orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
|
||||
out_aval.ndim == recover_in_aval.ndim and
|
||||
out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
|
||||
res.append(out_s.with_spec(recover_in_s.spec))
|
||||
else:
|
||||
res.append(out_s)
|
||||
return res
|
||||
|
||||
|
||||
def maybe_recover_user_shardings(
|
||||
old_shardings, new_shardings, old_avals, new_avals,
|
||||
intermediate_shardings=None, context_mesh: Mesh | None = None):
|
||||
intermediate_shardings=None, context_mesh: Mesh | None = None,
|
||||
orig_out_shardings=None):
|
||||
if orig_out_shardings is not None:
|
||||
new_shardings = try_matching_out_with_in_spec_for_all_auto(
|
||||
orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
|
||||
|
||||
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
|
||||
return new_shardings
|
||||
|
||||
@ -2936,6 +2969,8 @@ class UnloadedMeshExecutable:
|
||||
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
|
||||
compiler_options_kvs, pgle_profiler)
|
||||
|
||||
orig_out_shardings = out_shardings
|
||||
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None
|
||||
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
|
||||
@ -2972,7 +3007,7 @@ class UnloadedMeshExecutable:
|
||||
|
||||
out_shardings = maybe_recover_user_shardings(
|
||||
in_shardings, out_shardings, global_in_avals, global_out_avals,
|
||||
intermediate_shardings, context_mesh)
|
||||
intermediate_shardings, context_mesh, orig_out_shardings)
|
||||
|
||||
in_shardings = finalize_shardings(in_shardings, da)
|
||||
out_shardings = finalize_shardings(out_shardings, da)
|
||||
|
@ -5772,7 +5772,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
return a
|
||||
|
||||
out = f(arr, arr.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
def test_auto_user(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
|
||||
|
Loading…
x
Reference in New Issue
Block a user