Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output are correct.

Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering; otherwise things go horribly wrong. This is actually an XLA bug, but we can work around it in JAX for now.

PiperOrigin-RevId: 729307115
This commit is contained in:
Yash Katariya 2025-02-20 17:12:40 -08:00 committed by jax authors
parent bbc4fa7125
commit 250e2ee7da
3 changed files with 65 additions and 76 deletions

View File

@ -1030,6 +1030,8 @@ def _to_physical_op_sharding(
) -> xc.OpSharding | SdyArraySharding | None:
if sharding is None:
return None
if all_unconstrained(sharding, aval):
return None
if isinstance(sharding, AUTO):
if config.use_shardy_partitioner.value:
return sharding._to_sdy_sharding(aval.ndim) # type: ignore
@ -1071,10 +1073,8 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
def contains_unconstrained(s):
return (
isinstance(s, NamedSharding)
and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
)
return (isinstance(s, NamedSharding)
and PartitionSpec.UNCONSTRAINED in s._parsed_pspec)
def all_unconstrained(s, aval):
@ -1084,12 +1084,19 @@ def all_unconstrained(s, aval):
return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
return False
def _get_unconstrained_dimensions(s, aval):
class UnconstrainedVariants(NamedTuple):
  """Summary of where PartitionSpec.UNCONSTRAINED appears in an out sharding."""
  # True if at least one partition entry of the sharding is UNCONSTRAINED.
  contains_unconstrained: bool
  # True if every partition entry of the sharding is UNCONSTRAINED.
  all_unconstrained: bool
  # Indices of the UNCONSTRAINED dims; None when contains_unconstrained is
  # False (see _get_unconstrained_variants, which only builds the set then).
  unconstrained_dims: set[int] | None
def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants:
us = contains_unconstrained(s)
return (
us, all_unconstrained(s, aval),
({i for i, p in enumerate(s._parsed_pspec)
if p is PartitionSpec.UNCONSTRAINED} if us else None))
unconstrained_dims = ({i for i, p in enumerate(s._parsed_pspec)
if p is PartitionSpec.UNCONSTRAINED} if us else None)
return UnconstrainedVariants(
contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval),
unconstrained_dims=unconstrained_dims)
def lower_jaxpr_to_module(
module_name: str,
@ -1511,13 +1518,13 @@ def lower_jaxpr_to_fun(
for is_donated, types in zip(xla_donated_args, input_types)])
ir_result_shardings = None
unconstrained_shardings = None
unconstrained_variants = None
if result_shardings is not None:
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
for a, s, types in zip(output_avals, result_shardings, output_types)])
unconstrained_shardings = util.flatten(
[[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
unconstrained_variants = util.flatten(
[[_get_unconstrained_variants(s, a)] * len_ir_types(types)
for a, s, types in zip(output_avals, result_shardings, output_types)])
ir_result_memory_kinds = None
@ -1633,9 +1640,9 @@ def lower_jaxpr_to_fun(
attrs['jax.result_info'] = ir.StringAttr.get(name_)
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
unconstrained_shardings): # type: ignore
if sharding is not None and not us[0]:
for attrs, sharding, uv in zip(result_attrs, ir_result_shardings,
unconstrained_variants): # type: ignore
if sharding is not None and not uv.contains_unconstrained:
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
@ -1716,13 +1723,15 @@ def lower_jaxpr_to_fun(
if ir_result_shardings is not None:
temp_flat_outputs = []
for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
output_avals, unconstrained_shardings): # type: ignore
if us[0] and not us[1]:
for o, s, o_aval, uv in zip(flat_outputs, ir_result_shardings,
output_avals, unconstrained_variants): # type: ignore
if (s is not None and uv.contains_unconstrained and
not uv.all_unconstrained):
if config.use_shardy_partitioner.value:
s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
temp_flat_outputs.append(wrap_with_sharding_op(
entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
entry_lowering_ctx, o, o_aval, s,
unspecified_dims=uv.unconstrained_dims))
else:
temp_flat_outputs.append(o)
flat_outputs = temp_flat_outputs

View File

@ -2157,6 +2157,13 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts)
@lru_cache(maxsize=1024)
def _abstract_to_concrete_mesh(abstract_mesh, device_assignment):
  """Builds a concrete `Mesh` from an abstract mesh and a device assignment.

  Devices are looked up by flat index in `device_assignment` and reshaped to
  the abstract mesh's axis sizes; axis names and axis types are carried over
  unchanged. `lru_cache` avoids rebuilding the same Mesh for repeated
  (abstract_mesh, device_assignment) pairs.

  NOTE(review): both arguments must be hashable for the cache to work —
  presumably `device_assignment` is a tuple of devices here; verify at callers.
  """
  # Map each flat index to its device object; otypes=[object] keeps the
  # result an object array of devices rather than coercing them.
  np_dev = np.vectorize(lambda i: device_assignment[i],
                        otypes=[object])(np.arange(len(device_assignment)))
  return Mesh(np_dev.reshape(abstract_mesh.axis_sizes),
              abstract_mesh.axis_names, axis_types=abstract_mesh.axis_types)
def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
out_mem_kinds):
if device_assignment is None:
@ -2164,27 +2171,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
if len(device_assignment) == 1:
return shardings
np_dev = np.vectorize(lambda i: device_assignment[i],
otypes=[object])(np.arange(len(device_assignment)))
@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
axis_types=abstract_mesh.axis_types)
out = []
for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
if a.sharding.mesh.empty:
out.append(s)
elif a.sharding.mesh._are_all_axes_auto:
out.append(s)
else:
spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
for sp in a.sharding.spec])
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
out.append(NamedSharding(
_abstract_to_concrete_mesh(a.sharding.mesh), spec,
memory_kind=mem_kind))
_abstract_to_concrete_mesh(a.sharding.mesh, device_assignment),
spec, memory_kind=mem_kind))
else:
out.append(s)
return tuple(out)
@ -2534,15 +2534,22 @@ def _get_mesh_pspec_shardings_from_executable(
_orig_out_sharding_handlers = {}
def _gspmd_to_named_sharding(
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
out_s: GSPMDSharding, out_aval, orig_in_s: NamedSharding) -> NamedSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, NamedSharding)
assert isinstance(orig_in_s.mesh, Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
if (out_aval is not None and not out_aval.sharding.mesh.empty and
out_aval.sharding.mesh._are_all_axes_auto):
mesh = _abstract_to_concrete_mesh(
out_aval.sharding.mesh, out_s._device_assignment)
else:
mesh = orig_in_s.mesh
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh)
_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding
def _gspmd_to_positional_sharding(
out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding
) -> PositionalSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, PositionalSharding)
return sharding_impls._op_sharding_to_pos_sharding(
@ -2550,7 +2557,8 @@ def _gspmd_to_positional_sharding(
_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore
def _gspmd_to_single_device_sharding(
out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding
) -> SingleDeviceSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, SingleDeviceSharding)
return SingleDeviceSharding(
@ -2565,15 +2573,17 @@ def _get_out_sharding_from_orig_sharding(
for o, out_aval in safe_zip(out_shardings, out_avals):
if (isinstance(o, sharding_impls.GSPMDSharding) and
out_aval is not core.abstract_token):
if (orig_aval is not None and out_aval is not None and
out_aval.ndim == orig_aval.ndim
and sharding_impls.are_op_shardings_equal(
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
and o.memory_kind == orig_in_s.memory_kind):
# TODO(yashkatariya): Remove this condition and ask users to drop into
# explicit mode.
if (orig_aval is not None and out_aval is not None
and out_aval.ndim == orig_aval.ndim
and isinstance(orig_in_s, NamedSharding)
and out_aval.sharding.mesh == orig_in_s.mesh.abstract_mesh
and o.is_equivalent_to(orig_in_s, orig_aval.ndim)):
out.append(orig_in_s)
else:
try:
out.append(orig_handler(o, orig_in_s))
out.append(orig_handler(o, out_aval, orig_in_s))
except:
out.append(o)
else:
@ -2581,40 +2591,9 @@ def _get_out_sharding_from_orig_sharding(
return out
def try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
recover_in_s, recover_in_aval = None, None
for in_s, in_aval in safe_zip(in_shardings, in_avals):
if isinstance(in_s, NamedSharding):
recover_in_s, recover_in_aval = in_s, in_aval
break
if recover_in_s is None:
return new_out_shardings
res = []
for orig_out_s, out_s, out_aval in safe_zip(
orig_out_shardings, new_out_shardings, out_avals):
if (out_aval is not core.abstract_token and
mlir.all_unconstrained(orig_out_s, out_aval) and
isinstance(orig_out_s, NamedSharding) and
isinstance(out_s, NamedSharding) and
orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
out_aval.ndim == recover_in_aval.ndim and
out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
res.append(out_s.with_spec(recover_in_s.spec))
else:
res.append(out_s)
return res
def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals,
intermediate_shardings=None, context_mesh: Mesh | None = None,
orig_out_shardings=None):
if orig_out_shardings is not None:
new_shardings = try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
intermediate_shardings=None, context_mesh: Mesh | None = None):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings
@ -2831,7 +2810,7 @@ def _maybe_get_and_check_out_shardings(
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
try:
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # pytype: disable=wrong-arg-types
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig)) # pytype: disable=wrong-arg-types
except:
new_out_shardings.append(xla_s)
else:
@ -3004,7 +2983,7 @@ class UnloadedMeshExecutable:
out_shardings = maybe_recover_user_shardings(
in_shardings, out_shardings, global_in_avals, global_out_avals,
intermediate_shardings, context_mesh, orig_out_shardings)
intermediate_shardings, context_mesh)
in_shardings = finalize_shardings(in_shardings, da)
out_shardings = finalize_shardings(out_shardings, da)

View File

@ -131,7 +131,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
return shard_alike(x, y)[1]
out = f(inp)
self.assertEqual(out.sharding, s)
self.assertTrue(out.sharding.is_equivalent_to(s, out.ndim))
self.assertArraysEqual(out, np_inp)
def test_shard_map(self):
@ -268,7 +268,8 @@ class ShardAlikeTest(jtu.JaxTestCase):
x = jax.device_put(np.arange(8), s)
_, y = shard_alike(x, jnp.arange(8))
self.assertEqual(y.sharding, s)
self.assertTrue(y.sharding.is_equivalent_to(s, y.ndim))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())