mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted
PiperOrigin-RevId: 621991853
This commit is contained in:
parent
033992867f
commit
b322d399e1
@ -2757,11 +2757,6 @@ def _maybe_get_and_check_in_shardings(
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
|
||||
new_in_shardings.append(xla_s)
|
||||
else:
|
||||
# TODO(yashkatariya): Remove the if branch for abstract_token once
|
||||
# choosing input shardings by XLA is enabled again.
|
||||
if aval is core.abstract_token:
|
||||
new_in_shardings.append(orig)
|
||||
else:
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
|
Loading…
x
Reference in New Issue
Block a user