Move the replicated trailing dims check inside logical_op_sharding

PiperOrigin-RevId: 611277405
This commit is contained in:
Yash Katariya 2024-02-28 17:03:04 -08:00 committed by jax authors
parent b54ebc3f4c
commit 550ce44afd
2 changed files with 12 additions and 7 deletions

View File

@@ -2637,7 +2637,6 @@ def _maybe_get_and_check_in_shardings(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
@@ -2655,7 +2654,7 @@ def _maybe_get_and_check_in_shardings(
return new_in_shardings
def _get_out_shardings_from_executable(
def _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, device_assignment, global_out_avals,
num_ordered_effects, all_default_mem_kind
):
@@ -2671,7 +2670,6 @@ def _get_out_shardings_from_executable(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
new_out_shardings.append(xla_s)
else:
@@ -2823,7 +2821,7 @@ class UnloadedMeshExecutable:
in_shardings = _maybe_get_and_check_in_shardings(
xla_executable, in_shardings, tuple(da), global_in_avals,
len(ordered_effects))
out_shardings = _get_out_shardings_from_executable(
out_shardings = _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
else:

View File

@@ -375,6 +375,9 @@ class KeyTyRules:
@staticmethod
def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
# The trailing dims should always be replicated.
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
if dispatch.is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
@@ -475,9 +478,13 @@ class KeyTyRules:
return random_wrap(physical_result, impl=aval.dtype._impl)
@staticmethod
def check_replicated_trailing_dims(sharding: GSPMDSharding, aval):
partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding)
num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim
def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s)
num_trailing_dims = phys_aval.ndim - aval.ndim
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
raise AssertionError(
"The trailing dims of extended dtypes should be replicated. Got"