From 550ce44afdef0e9ca4e0f80ea1c5a718265f787d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 Feb 2024 17:03:04 -0800 Subject: [PATCH] Move the replicated trailing dims check inside logical_op_sharding PiperOrigin-RevId: 611277405 --- jax/_src/interpreters/pxla.py | 6 ++---- jax/_src/prng.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3b2065491..a339122fe 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2637,7 +2637,6 @@ def _maybe_get_and_check_in_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval) xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: @@ -2655,7 +2654,7 @@ def _maybe_get_and_check_in_shardings( return new_in_shardings -def _get_out_shardings_from_executable( +def _maybe_get_and_check_out_shardings( xla_executable, out_shardings, device_assignment, global_out_avals, num_ordered_effects, all_default_mem_kind ): @@ -2671,7 +2670,6 @@ def _get_out_shardings_from_executable( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval) xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) new_out_shardings.append(xla_s) else: @@ -2823,7 +2821,7 @@ class UnloadedMeshExecutable: in_shardings = _maybe_get_and_check_in_shardings( xla_executable, in_shardings, tuple(da), global_in_avals, len(ordered_effects)) - out_shardings = _get_out_shardings_from_executable( + out_shardings = _maybe_get_and_check_out_shardings( xla_executable, out_shardings, tuple(da), global_out_avals, len(ordered_effects), all_default_mem_kind) else: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index a1f2906f7..8cf8384ff 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -375,6 +375,9 @@ class KeyTyRules: @staticmethod def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding: + # The trailing dims should always be replicated. + aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval) + if dispatch.is_single_device_sharding(phys_sharding): return phys_sharding elif isinstance(phys_sharding, PmapSharding): @@ -475,9 +478,13 @@ class KeyTyRules: return random_wrap(physical_result, impl=aval.dtype._impl) @staticmethod - def check_replicated_trailing_dims(sharding: GSPMDSharding, aval): - partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding) - num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim + def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval): + if isinstance(sharding, PmapSharding): + return + phys_aval = core.physical_aval(aval) + hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) + partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s) + num_trailing_dims = phys_aval.ndim - aval.ndim if not all(i == 1 for i in partitions[-num_trailing_dims:]): raise AssertionError( "The trailing dims of extended dtypes should be replicated. Got"