Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted

PiperOrigin-RevId: 621991853
This commit is contained in:
Yash Katariya 2024-04-04 15:38:49 -07:00 committed by jax authors
parent 033992867f
commit b322d399e1

View File

@ -2758,21 +2758,16 @@ def _maybe_get_and_check_in_shardings(
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
# TODO(yashkatariya): Remove the if branch for abstract_token once
# choosing input shardings by XLA is enabled again.
if aval is core.abstract_token:
new_in_shardings.append(orig)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
return new_in_shardings