mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted
PiperOrigin-RevId: 621991853
This commit is contained in:
parent
033992867f
commit
b322d399e1
@ -2758,21 +2758,16 @@ def _maybe_get_and_check_in_shardings(
|
||||
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
|
||||
new_in_shardings.append(xla_s)
|
||||
else:
|
||||
# TODO(yashkatariya): Remove the if branch for abstract_token once
|
||||
# choosing input shardings by XLA is enabled again.
|
||||
if aval is core.abstract_token:
|
||||
new_in_shardings.append(orig)
|
||||
else:
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
# MANUAL HloSharding comes from other partitioning frameworks.
|
||||
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
|
||||
not xla_hlo_s.is_manual() and
|
||||
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||
"(User sharding)")
|
||||
new_in_shardings.append(orig)
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
# MANUAL HloSharding comes from other partitioning frameworks.
|
||||
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
|
||||
not xla_hlo_s.is_manual() and
|
||||
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||
"(User sharding)")
|
||||
new_in_shardings.append(orig)
|
||||
return new_in_shardings
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user