From b322d399e1fd68521c94b4207ca35aabd5858ca3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 4 Apr 2024 15:38:49 -0700 Subject: [PATCH] Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted PiperOrigin-RevId: 621991853 --- jax/_src/interpreters/pxla.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f561bbde9..f173bc50b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2758,21 +2758,16 @@ def _maybe_get_and_check_in_shardings( xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: - # TODO(yashkatariya): Remove the if branch for abstract_token once - # choosing input shardings by XLA is enabled again. - if aval is core.abstract_token: - new_in_shardings.append(orig) - else: - xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore - orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore - # MANUAL HloSharding comes from other partitioning frameworks. - if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and - not xla_hlo_s.is_manual() and - (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))): - raise AssertionError( - f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " - "(User sharding)") - new_in_shardings.append(orig) + xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore + orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore + # MANUAL HloSharding comes from other partitioning frameworks. + if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and + not xla_hlo_s.is_manual() and + (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))): + raise AssertionError( + f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " + "(User sharding)") + new_in_shardings.append(orig) return new_in_shardings