mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move the replicated trailing dims check inside logical_op_sharding
PiperOrigin-RevId: 611277405
This commit is contained in:
parent
b54ebc3f4c
commit
550ce44afd
@ -2637,7 +2637,6 @@ def _maybe_get_and_check_in_shardings(
|
||||
if is_unspecified(orig):
|
||||
if (aval is not core.abstract_token and
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
|
||||
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
|
||||
new_in_shardings.append(xla_s)
|
||||
else:
|
||||
@ -2655,7 +2654,7 @@ def _maybe_get_and_check_in_shardings(
|
||||
return new_in_shardings
|
||||
|
||||
|
||||
def _get_out_shardings_from_executable(
|
||||
def _maybe_get_and_check_out_shardings(
|
||||
xla_executable, out_shardings, device_assignment, global_out_avals,
|
||||
num_ordered_effects, all_default_mem_kind
|
||||
):
|
||||
@ -2671,7 +2670,6 @@ def _get_out_shardings_from_executable(
|
||||
if is_unspecified(orig):
|
||||
if (aval is not core.abstract_token and
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
|
||||
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
|
||||
new_out_shardings.append(xla_s)
|
||||
else:
|
||||
@ -2823,7 +2821,7 @@ class UnloadedMeshExecutable:
|
||||
in_shardings = _maybe_get_and_check_in_shardings(
|
||||
xla_executable, in_shardings, tuple(da), global_in_avals,
|
||||
len(ordered_effects))
|
||||
out_shardings = _get_out_shardings_from_executable(
|
||||
out_shardings = _maybe_get_and_check_out_shardings(
|
||||
xla_executable, out_shardings, tuple(da), global_out_avals,
|
||||
len(ordered_effects), all_default_mem_kind)
|
||||
else:
|
||||
|
@ -375,6 +375,9 @@ class KeyTyRules:
|
||||
|
||||
@staticmethod
|
||||
def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
|
||||
# The trailing dims should always be replicated.
|
||||
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
|
||||
|
||||
if dispatch.is_single_device_sharding(phys_sharding):
|
||||
return phys_sharding
|
||||
elif isinstance(phys_sharding, PmapSharding):
|
||||
@ -475,9 +478,13 @@ class KeyTyRules:
|
||||
return random_wrap(physical_result, impl=aval.dtype._impl)
|
||||
|
||||
@staticmethod
|
||||
def check_replicated_trailing_dims(sharding: GSPMDSharding, aval):
|
||||
partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding)
|
||||
num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim
|
||||
def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
|
||||
if isinstance(sharding, PmapSharding):
|
||||
return
|
||||
phys_aval = core.physical_aval(aval)
|
||||
hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
|
||||
partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s)
|
||||
num_trailing_dims = phys_aval.ndim - aval.ndim
|
||||
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
|
||||
raise AssertionError(
|
||||
"The trailing dims of extended dtypes should be replicated. Got"
|
||||
|
Loading…
x
Reference in New Issue
Block a user