Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output are correct.

Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering; otherwise things go horribly wrong. This is actually an XLA bug, but we can work around it in JAX for now.

PiperOrigin-RevId: 729307115
commit 250e2ee7da
parent bbc4fa7125
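In short: when recovering a NamedSharding from the compiler's GSPMDSharding, the mesh now comes from out_aval whenever that aval's abstract mesh is fully auto, and only otherwise from the original input sharding; separately, fully-unconstrained outputs are lowered with an unspecified op sharding (the `all_unconstrained` early return in the first hunk). A minimal sketch of the mesh-selection rule, reusing helper names from the diff below; `_pick_output_mesh` is a made-up name for illustration, not a function added by this change:

def _pick_output_mesh(out_s, out_aval, orig_in_s):
  # Sketch only: prefer the mesh carried by out_aval when all of its axes are
  # auto, so the recovered NamedSharding gets the right axis types; otherwise
  # fall back to the mesh of the original input NamedSharding.
  if (out_aval is not None and not out_aval.sharding.mesh.empty and
      out_aval.sharding.mesh._are_all_axes_auto):
    return _abstract_to_concrete_mesh(
        out_aval.sharding.mesh, out_s._device_assignment)
  return orig_in_s.mesh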
@@ -1030,6 +1030,8 @@ def _to_physical_op_sharding(
 ) -> xc.OpSharding | SdyArraySharding | None:
   if sharding is None:
     return None
+  if all_unconstrained(sharding, aval):
+    return None
   if isinstance(sharding, AUTO):
     if config.use_shardy_partitioner.value:
       return sharding._to_sdy_sharding(aval.ndim)  # type: ignore
@@ -1071,10 +1073,8 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:


 def contains_unconstrained(s):
-  return (
-      isinstance(s, NamedSharding)
-      and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
-  )
+  return (isinstance(s, NamedSharding)
+          and PartitionSpec.UNCONSTRAINED in s._parsed_pspec)


 def all_unconstrained(s, aval):
@@ -1084,12 +1084,19 @@ def all_unconstrained(s, aval):
     return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
   return False

-def _get_unconstrained_dimensions(s, aval):
+class UnconstrainedVariants(NamedTuple):
+  contains_unconstrained: bool
+  all_unconstrained: bool
+  unconstrained_dims: set[int] | None
+
+def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants:
   us = contains_unconstrained(s)
-  return (
-      us, all_unconstrained(s, aval),
-      ({i for i, p in enumerate(s._parsed_pspec)
-        if p is PartitionSpec.UNCONSTRAINED} if us else None))
+  unconstrained_dims = ({i for i, p in enumerate(s._parsed_pspec)
+                         if p is PartitionSpec.UNCONSTRAINED} if us else None)
+  return UnconstrainedVariants(
+      contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval),
+      unconstrained_dims=unconstrained_dims)


 def lower_jaxpr_to_module(
     module_name: str,
@@ -1511,13 +1518,13 @@ def lower_jaxpr_to_fun(
          for is_donated, types in zip(xla_donated_args, input_types)])

   ir_result_shardings = None
-  unconstrained_shardings = None
+  unconstrained_variants = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
-    unconstrained_shardings = util.flatten(
-        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+    unconstrained_variants = util.flatten(
+        [[_get_unconstrained_variants(s, a)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])

   ir_result_memory_kinds = None
@@ -1633,9 +1640,9 @@ def lower_jaxpr_to_fun(
         attrs['jax.result_info'] = ir.StringAttr.get(name_)

   if use_sharding_annotations and ir_result_shardings is not None:
-    for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
-                                   unconstrained_shardings):  # type: ignore
-      if sharding is not None and not us[0]:
+    for attrs, sharding, uv in zip(result_attrs, ir_result_shardings,
+                                   unconstrained_variants):  # type: ignore
+      if sharding is not None and not uv.contains_unconstrained:
         if config.use_shardy_partitioner.value:
           attrs["sdy.sharding"] = get_sharding_attr(sharding)
         else:
@@ -1716,13 +1723,15 @@ def lower_jaxpr_to_fun(

   if ir_result_shardings is not None:
     temp_flat_outputs = []
-    for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                output_avals, unconstrained_shardings):  # type: ignore
-      if us[0] and not us[1]:
+    for o, s, o_aval, uv in zip(flat_outputs, ir_result_shardings,
+                                output_avals, unconstrained_variants):  # type: ignore
+      if (s is not None and uv.contains_unconstrained and
+          not uv.all_unconstrained):
         if config.use_shardy_partitioner.value:
           s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
         temp_flat_outputs.append(wrap_with_sharding_op(
-            entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+            entry_lowering_ctx, o, o_aval, s,
+            unspecified_dims=uv.unconstrained_dims))
       else:
         temp_flat_outputs.append(o)
     flat_outputs = temp_flat_outputs
@@ -2157,6 +2157,13 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
   return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
           donated_invars, out_shardings, out_layouts)

+@lru_cache(maxsize=1024)
+def _abstract_to_concrete_mesh(abstract_mesh, device_assignment):
+  np_dev = np.vectorize(lambda i: device_assignment[i],
+                        otypes=[object])(np.arange(len(device_assignment)))
+  return Mesh(np_dev.reshape(abstract_mesh.axis_sizes),
+              abstract_mesh.axis_names, axis_types=abstract_mesh.axis_types)
+
 def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
                                        out_mem_kinds):
   if device_assignment is None:
@@ -2164,27 +2171,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
   if len(device_assignment) == 1:
     return shardings

-  np_dev = np.vectorize(lambda i: device_assignment[i],
-                        otypes=[object])(np.arange(len(device_assignment)))
-
-  @lru_cache(maxsize=128)
-  def _abstract_to_concrete_mesh(abstract_mesh):
-    return Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
-        axis_types=abstract_mesh.axis_types)
-
   out = []
   for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       if a.sharding.mesh.empty:
         out.append(s)
+      elif a.sharding.mesh._are_all_axes_auto:
+        out.append(s)
       else:
         spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                                 for sp in a.sharding.spec])
                 if a.sharding.mesh._any_axis_auto else a.sharding.spec)
         out.append(NamedSharding(
-            _abstract_to_concrete_mesh(a.sharding.mesh), spec,
-            memory_kind=mem_kind))
+            _abstract_to_concrete_mesh(a.sharding.mesh, device_assignment),
+            spec, memory_kind=mem_kind))
     else:
       out.append(s)
   return tuple(out)
@@ -2534,15 +2534,22 @@ def _get_mesh_pspec_shardings_from_executable(
 _orig_out_sharding_handlers = {}

 def _gspmd_to_named_sharding(
-    out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: NamedSharding) -> NamedSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, NamedSharding)
   assert isinstance(orig_in_s.mesh, Mesh)
-  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
+  if (out_aval is not None and not out_aval.sharding.mesh.empty and
+      out_aval.sharding.mesh._are_all_axes_auto):
+    mesh = _abstract_to_concrete_mesh(
+        out_aval.sharding.mesh, out_s._device_assignment)
+  else:
+    mesh = orig_in_s.mesh
+  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh)
 _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding

 def _gspmd_to_positional_sharding(
-    out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding
+) -> PositionalSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, PositionalSharding)
   return sharding_impls._op_sharding_to_pos_sharding(
@@ -2550,7 +2557,8 @@ def _gspmd_to_positional_sharding(
 _orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding  # type: ignore

 def _gspmd_to_single_device_sharding(
-    out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding
+) -> SingleDeviceSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, SingleDeviceSharding)
   return SingleDeviceSharding(
@@ -2565,15 +2573,17 @@ def _get_out_sharding_from_orig_sharding(
   for o, out_aval in safe_zip(out_shardings, out_avals):
     if (isinstance(o, sharding_impls.GSPMDSharding) and
         out_aval is not core.abstract_token):
-      if (orig_aval is not None and out_aval is not None and
-          out_aval.ndim == orig_aval.ndim
-          and sharding_impls.are_op_shardings_equal(
-              o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
-          and o.memory_kind == orig_in_s.memory_kind):
+      # TODO(yashkatariya): Remove this condition and ask users to drop into
+      # explicit mode.
+      if (orig_aval is not None and out_aval is not None
+          and out_aval.ndim == orig_aval.ndim
+          and isinstance(orig_in_s, NamedSharding)
+          and out_aval.sharding.mesh == orig_in_s.mesh.abstract_mesh
+          and o.is_equivalent_to(orig_in_s, orig_aval.ndim)):
         out.append(orig_in_s)
       else:
         try:
-          out.append(orig_handler(o, orig_in_s))
+          out.append(orig_handler(o, out_aval, orig_in_s))
         except:
           out.append(o)
     else:
@@ -2581,40 +2591,9 @@ def _get_out_sharding_from_orig_sharding(
   return out


-def try_matching_out_with_in_spec_for_all_auto(
-    orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
-  recover_in_s, recover_in_aval = None, None
-  for in_s, in_aval in safe_zip(in_shardings, in_avals):
-    if isinstance(in_s, NamedSharding):
-      recover_in_s, recover_in_aval = in_s, in_aval
-      break
-  if recover_in_s is None:
-    return new_out_shardings
-
-  res = []
-  for orig_out_s, out_s, out_aval in safe_zip(
-      orig_out_shardings, new_out_shardings, out_avals):
-    if (out_aval is not core.abstract_token and
-        mlir.all_unconstrained(orig_out_s, out_aval) and
-        isinstance(orig_out_s, NamedSharding) and
-        isinstance(out_s, NamedSharding) and
-        orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
-        out_aval.ndim == recover_in_aval.ndim and
-        out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
-      res.append(out_s.with_spec(recover_in_s.spec))
-    else:
-      res.append(out_s)
-  return res
-
-
 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: Mesh | None = None,
-    orig_out_shardings=None):
-  if orig_out_shardings is not None:
-    new_shardings = try_matching_out_with_in_spec_for_all_auto(
-        orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
-
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings

@@ -2831,7 +2810,7 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       try:
-        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # pytype: disable=wrong-arg-types
+        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig))  # pytype: disable=wrong-arg-types
       except:
         new_out_shardings.append(xla_s)
     else:
@@ -3004,7 +2983,7 @@ class UnloadedMeshExecutable:

     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals,
-        intermediate_shardings, context_mesh, orig_out_shardings)
+        intermediate_shardings, context_mesh)

     in_shardings = finalize_shardings(in_shardings, da)
     out_shardings = finalize_shardings(out_shardings, da)
@@ -131,7 +131,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
       return shard_alike(x, y)[1]

     out = f(inp)
-    self.assertEqual(out.sharding, s)
+    self.assertTrue(out.sharding.is_equivalent_to(s, out.ndim))
     self.assertArraysEqual(out, np_inp)

   def test_shard_map(self):
@@ -268,7 +268,8 @@ class ShardAlikeTest(jtu.JaxTestCase):

    x = jax.device_put(np.arange(8), s)
    _, y = shard_alike(x, jnp.arange(8))
-    self.assertEqual(y.sharding, s)
+    self.assertTrue(y.sharding.is_equivalent_to(s, y.ndim))


 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())