From 12730280180ea28c19f87942b295410003cf5b55 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 3 Jun 2024 14:52:08 -0700 Subject: [PATCH] Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. PiperOrigin-RevId: 639920049 --- jax/BUILD | 1 + jax/_src/dispatch.py | 10 +-- jax/_src/earray.py | 5 +- jax/_src/interpreters/mlir.py | 19 ++++- jax/_src/interpreters/pxla.py | 27 ++------ jax/_src/lax/lax.py | 15 ---- jax/_src/pjit.py | 2 +- jax/_src/prng.py | 126 ++++------------------------------ jax/_src/sharding_impls.py | 109 +++++++++++++++++++++++++++++ jax/experimental/shard_map.py | 4 +- tests/dtypes_test.py | 13 ---- tests/lax_test.py | 16 ----- 12 files changed, 151 insertions(+), 196 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index bca6789ca..359978fe3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -789,6 +789,7 @@ pytype_strict_library( srcs = ["_src/sharding_impls.py"], deps = [ ":config", + ":core", ":mesh", ":op_shardings", ":partition_spec", diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 220fe063b..a59126a65 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -48,8 +48,8 @@ from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, - GSPMDSharding, TransferToMemoryKind) + SingleDeviceSharding, NamedSharding, XLACompatibleSharding, + GSPMDSharding, TransferToMemoryKind, is_single_device_sharding) from jax._src.layout import Layout, DeviceLocalLayout @@ -163,12 +163,6 @@ def wait_for_tokens(): runtime_tokens.block_until_ready() -def is_single_device_sharding(sharding: Sharding) -> bool: - # Special case PmapSharding here because PmapSharding maps away an axis - # and needs to be handled separately.test_pjit_single_device_sharding_add - return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) - - @contextlib.contextmanager def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): if _on_exit: diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 176abcc5a..fcf0e9c3c 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -20,6 +20,7 @@ from jax._src import api_util from jax._src import basearray from jax._src import core from jax._src import tree_util +from jax._src import sharding_impls from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.util import safe_zip, safe_map @@ -80,7 +81,7 @@ class EArray(basearray.Array): @property def sharding(self): phys_sharding = self._data.sharding - return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding) + return sharding_impls.logical_sharding(self.aval, phys_sharding) # TODO(mattjj): not implemented below here, need more methods from ArrayImpl @@ -99,7 +100,7 @@ class EArray(basearray.Array): def _earray_shard_arg_handler(x, sharding): arr = x._data - phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding) + phys_sharding = sharding_impls.physical_sharding(x.aval, sharding) return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f306af7db..62bb951f2 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ 
-818,7 +818,7 @@ def _to_physical_op_sharding( return _to_physical_op_sharding(aval.inner_aval, sharding) assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if dtypes.issubdtype(aval.dtype, dtypes.extended): - sharding = aval.dtype._rules.physical_sharding(aval, sharding) + sharding = sharding_impls.physical_sharding(aval, sharding) aval = core.physical_aval(aval) return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore @@ -1376,7 +1376,7 @@ def lower_jaxpr_to_fun( if ir_arg_shardings is not None and name == "main": flat_args = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_args, ir_arg_shardings, input_avals) @@ -1417,7 +1417,7 @@ def lower_jaxpr_to_fun( if ir_result_shardings is not None and name == "main": flat_outputs = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals) @@ -1441,6 +1441,19 @@ def wrap_with_memory_kind( return op.result +def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: + # Set the sharding of extended dtypes to be UNCONSTRAINED + # (i.e. XLA will choose) on aval.shape. + # For the trailing dims i.e. the dimension of key_shape on the base_array, + # the sharding is set to be REPLICATED always. + # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), + # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). + # The below custom call achieves the sharding like above example. 
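+  # In other words (following the example above): only the logical dims of
+  # `aval` are left unspecified, so the partitioner may choose their sharding
+  # freely, while the trailing dims of the physical value keep the fully
+  # replicated HloSharding passed to the sharding op below.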
+ return wrap_with_sharding_op( + ctx, val, aval, xc.HloSharding.replicate().to_proto(), + unspecified_dims=set(range(aval.ndim))) + + def _emit_lowering_rule_as_fun(lowering_rule, ctx: LoweringRuleContext) -> func_dialect.FuncOp: """Emits the contents of a lowering rule as a private function.""" diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0253fb697..f1a46a386 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2563,20 +2563,10 @@ def _register_out_sharding_handler( ) -> None: _orig_out_sharding_handlers[sharding_cls] = handler - -def _gspmd_to_named_sharding_via_mesh( - out_s: sharding_impls.GSPMDSharding, - mesh: Mesh) -> sharding_impls.NamedSharding: - parsed_pspec = sharding_impls.parse_flatten_op_sharding( - out_s._hlo_sharding, mesh)[0] - return create_mesh_pspec_sharding( - mesh, parsed_pspec.get_partition_spec(), parsed_pspec, - out_s.memory_kind) - def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: - return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) + return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( sharding_impls.NamedSharding, _gspmd_to_named_sharding) @@ -2793,7 +2783,7 @@ def _maybe_get_and_check_in_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) @@ -2829,7 +2819,7 @@ def _maybe_get_and_check_out_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) @@ -3120,7 +3110,7 @@ class MeshExecutable(stages.XlaExecutable): kept_var_bitvec = [i in self._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(self._in_shardings, self.in_avals) @@ -3186,14 +3176,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): return in_shardings, out_shardings, committed, tuple(local_devices) -@util.cache() -def create_mesh_pspec_sharding( - mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, - memory_kind: str | None = None) -> sharding_impls.NamedSharding: - if pspec is None: - pspec, parsed_pspec = PartitionSpec(), None - return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, - memory_kind=memory_kind) +create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding def check_device_backend_on_shardings(shardings) -> bool: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e78b34b8f..431319db6 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5209,14 +5209,6 @@ class BIntRules: return core.DArray(aval, phys_handler(bufs)) return handler - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def convert_from(bint_dtype, other_dtype) 
-> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) @@ -5225,12 +5217,5 @@ class BIntRules: def convert_to(other_dtype, bint_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass core.bint._rules = BIntRules diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8e71d56a2..38f06d84d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -269,7 +269,7 @@ def _get_fastpath_data( kept_var_bitvec = [i in executable._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(executable._in_shardings, executable.in_avals) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 57038e463..4d4a045a3 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -33,11 +33,9 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp -from jax._src import sharding_specs from jax._src import source_info_util from jax._src import tree_util as tree_util_internal from jax._src import typing -from jax._src import op_shardings from jax._src.api import jit, vmap from jax._src.dtypes import float0 from jax._src.interpreters import ad @@ -53,9 +51,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) -from jax._src.partition_spec import PartitionSpec from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding) + NamedSharding, PmapSharding, physical_sharding, logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -235,8 +232,7 @@ class PRNGKeyArray(jax.Array): @property def sharding(self): - phys_sharding = self._base_array.sharding - return KeyTyRules.logical_sharding(self.aval, phys_sharding) + return logical_sharding(self.aval, self._base_array.sharding) def _is_scalar(self): base_ndim = len(self._impl.key_shape) @@ -324,53 +320,6 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): base_ndim = len(impl.key_shape) return base_arr_shape[:-base_ndim] -def make_key_array_phys_sharding(aval, sharding): - if dispatch.is_single_device_sharding(sharding): - return sharding - elif isinstance(sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape) - phys_sharding_spec = sharding_specs.ShardingSpec( - sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), - mesh_mapping=sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=sharding.devices, - sharding_spec=phys_sharding_spec) - elif isinstance(sharding, NamedSharding): - key_shape = aval.dtype._impl.key_shape - trailing_spec = [None] * len(key_shape) - return NamedSharding( - sharding.mesh, - PartitionSpec(*sharding.spec, *trailing_spec)) - else: - hlos = sharding._to_xla_hlo_sharding(aval.ndim) - return GSPMDSharding( - sharding._device_assignment, physical_hlo_sharding(aval, hlos)) - - -def get_logical_gspmd_sharding(aval, phys_sharding): - key_shape = aval.dtype._impl.key_shape - phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( - aval.ndim 
+ len(key_shape)) - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - phys_hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - # Create logical sharding by cutting off the replicated trailing dims. - logical_op_sharding = phys_hlo_sharding.to_proto().clone() - tad = partitions[:-len(key_shape)] + suffix - logical_op_sharding.tile_assignment_dimensions = tad - return GSPMDSharding(phys_sharding._device_assignment, - xc.HloSharding.from_proto(logical_op_sharding)) - - -def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - key_shape = aval.dtype._impl.key_shape - new_op_sharding = hlo_sharding.to_proto().clone() - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - tad = partitions + [1] * len(key_shape) + suffix - new_op_sharding.tile_assignment_dimensions = tad - return xc.HloSharding.from_proto(new_op_sharding) - class KeyTyRules: @@ -393,32 +342,6 @@ class KeyTyRules: def physical_const(val) -> Array: return val._base_array - @staticmethod - def physical_sharding( - aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding: - return make_key_array_phys_sharding(aval, sharding) - - @staticmethod - def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding: - # The trailing dims should always be replicated. - aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval) - - if dispatch.is_single_device_sharding(phys_sharding): - return phys_sharding - elif isinstance(phys_sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - logical_sharding_spec = sharding_specs.ShardingSpec( - sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)], - mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=phys_sharding.devices, - sharding_spec=logical_sharding_spec) - elif isinstance(phys_sharding, NamedSharding): - logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) - return pxla._gspmd_to_named_sharding_via_mesh( - logical_gs, phys_sharding.mesh) - else: - return get_logical_gspmd_sharding(aval, phys_sharding) - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -434,7 +357,7 @@ class KeyTyRules: # set up a grounded sharding (with a grounded sharding spec) if isinstance(sharding, (PmapSharding, NamedSharding)): - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) else: assert False, f'impossible sharding {sharding} in local sharded result handler' @@ -456,7 +379,7 @@ class KeyTyRules: phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - phys_sharding = make_key_array_phys_sharding(aval, out_sharding) + phys_sharding = physical_sharding(aval, out_sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) @@ -468,7 +391,7 @@ class KeyTyRules: phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_arrays = [random_unwrap(arr) for arr in arrays] - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) @@ -477,8 +400,9 @@ class KeyTyRules: def device_put_sharded(vals, aval, sharding, devices): physical_aval = 
core.physical_aval(aval) physical_buffers = tree_util.tree_map(random_unwrap, vals) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices)) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put(physical_aval, phys_sharding, + physical_buffers, list(devices)) return random_wrap(physical_result, impl=aval.dtype._impl) @staticmethod @@ -486,37 +410,11 @@ class KeyTyRules: physical_aval = core.physical_aval(aval) assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put( + physical_aval, phys_sharding, [physical_buf] * len(devices), devices) return random_wrap(physical_result, impl=aval.dtype._impl) - @staticmethod - def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval): - if isinstance(sharding, PmapSharding): - return - phys_aval = core.physical_aval(aval) - hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) - partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s) - num_trailing_dims = phys_aval.ndim - aval.ndim - if not all(i == 1 for i in partitions[-num_trailing_dims:]): - raise AssertionError( - "The trailing dims of extended dtypes should be replicated. Got" - f" sharding: {sharding}, partitions: {partitions}, " - f"num_trailing_dims: {num_trailing_dims}") - - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - # Set the sharding of extended dtypes to be UNCONSTRAINED - # (i.e. XLA will choose) on aval.shape. - # For the trailing dims i.e. the dimension of key_shape on the base_array, - # the sharding is set to be REPLICATED always. - # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), - # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). - # The below custom call achieves the sharding like above example. 
- return mlir.wrap_with_sharding_op( - ctx, val, aval, xc.HloSharding.replicate().to_proto(), - unspecified_dims=set(range(aval.ndim))) - @staticmethod def tangent_dtype(_): return dtypes.float0 @@ -571,7 +469,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): arr = x._base_array - phys_sharding = make_key_array_phys_sharding(x.aval, sharding) + phys_sharding = physical_sharding(x.aval, sharding) return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 7a03e5ccc..c4299e759 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -30,6 +30,7 @@ from jax._src import sharding_specs from jax._src import tree_util from jax._src import util from jax._src import xla_bridge +from jax._src import core from jax._src.lib import xla_client as xc from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated, @@ -1437,3 +1438,111 @@ def num_addressable_indices( }) shard_size = tensor_sharding.shard_shape(global_shape)[dim] return shard_size * num_unique_slices + + +def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + new_op_sharding = hlo_sharding.to_proto().clone() + partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + tad = partitions + [1] * elt_aval.ndim + suffix + new_op_sharding.tile_assignment_dimensions = tad + return xc.HloSharding.from_proto(new_op_sharding) + +def is_single_device_sharding(sharding: sharding.Sharding) -> bool: + # Special case PmapSharding here because PmapSharding maps away an axis + # and needs to be handled separately.test_pjit_single_device_sharding_add + return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) + +def make_key_array_phys_sharding(aval, sharding): + if is_single_device_sharding(sharding): + return sharding + elif isinstance(sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim + phys_sharding_spec = sharding_specs.ShardingSpec( + sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), + mesh_mapping=sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=sharding.devices, + sharding_spec=phys_sharding_spec) + elif isinstance(sharding, NamedSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_spec = [None] * elt_aval.ndim + return NamedSharding( + sharding.mesh, + PartitionSpec(*sharding.spec, *trailing_spec)) + else: + hlos = sharding._to_xla_hlo_sharding(aval.ndim) + return GSPMDSharding( + sharding._device_assignment, physical_hlo_sharding(aval, hlos)) + + +def physical_sharding( + aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding: + return make_key_array_phys_sharding(aval, sharding) + + +def get_logical_gspmd_sharding(aval, phys_sharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( + aval.ndim + elt_aval.ndim) + partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + # Create logical sharding by cutting off the replicated trailing dims. 
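+  # For instance (hypothetical numbers): with a 2-D physical element
+  # (elt_aval.ndim == 2) and physical tile_assignment_dimensions [4, 2, 1, 1],
+  # the logical sharding ends up with tile_assignment_dimensions [4, 2]; a
+  # replica count > 1 would be re-appended as the trailing entry via `suffix`.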
+ logical_op_sharding = phys_hlo_sharding.to_proto().clone() + tad = partitions[:-elt_aval.ndim] + suffix + logical_op_sharding.tile_assignment_dimensions = tad + return GSPMDSharding(phys_sharding._device_assignment, + xc.HloSharding.from_proto(logical_op_sharding)) + +def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval): + if isinstance(sharding, PmapSharding): + return + phys_aval = core.physical_aval(aval) + hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) + partitions, _ = get_num_ways_dim_sharded(hlo_s) + num_trailing_dims = phys_aval.ndim - aval.ndim + if not all(i == 1 for i in partitions[-num_trailing_dims:]): + raise AssertionError( + "The trailing dims of extended dtypes should be replicated. Got" + f" sharding: {sharding}, partitions: {partitions}, " + f"num_trailing_dims: {num_trailing_dims}") + +def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding: + # The trailing dims should always be replicated. + check_replicated_trailing_dims(phys_sharding, aval) + + if is_single_device_sharding(phys_sharding): + return phys_sharding + elif isinstance(phys_sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + logical_sharding_spec = sharding_specs.ShardingSpec( + sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim], + mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=phys_sharding.devices, + sharding_spec=logical_sharding_spec) + elif isinstance(phys_sharding, NamedSharding): + logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + return _gspmd_to_named_sharding_via_mesh( + logical_gs, phys_sharding.mesh) + else: + return get_logical_gspmd_sharding(aval, phys_sharding) + + +@util.cache() +def create_mesh_pspec_sharding( + mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None, + memory_kind: str | None = None) -> NamedSharding: + if pspec is None: + pspec, parsed_pspec = PartitionSpec(), None + return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, + memory_kind=memory_kind) + + +def _gspmd_to_named_sharding_via_mesh( + out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + parsed_pspec = parse_flatten_op_sharding( + out_s._hlo_sharding, mesh)[0] + return create_mesh_pspec_sharding( + mesh, parsed_pspec.get_partition_spec(), parsed_pspec, + out_s.memory_kind) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 86965b251..2fabdcf38 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -653,7 +653,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = aval_in.dtype._rules.physical_sharding(aval_in, ns) + ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() unspecified = set(range(aval_in.ndim)) if auto else set() @@ -667,7 +667,7 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - ns = aval_out.dtype._rules.physical_sharding(aval_out, ns) + ns = sharding_impls.physical_sharding(aval_out, ns) aval_out = core.physical_aval(aval_out) unspecified = set(range(aval_out.ndim)) if auto else set() manual_proto = 
pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 363bb39fe..668ea4c0b 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -578,15 +578,6 @@ class EArrayTest(jtu.JaxTestCase): def physical_element_aval(foo_dtype): return core.ShapedArray((), dtypes.dtype('float32')) - @staticmethod - def replicate_trailing_dims(ctx, val, aval): - del ctx, aval - return val - - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): phys_sharding = out_sharding # unlike KeyTyRules, assume same shape @@ -595,10 +586,6 @@ class EArrayTest(jtu.JaxTestCase): phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) return lambda bufs: earray.EArray(aval, phys_handler(bufs)) - @staticmethod - def physical_sharding(aval, sharding): - return sharding # unlike KeyTyRules, assume same shape - @dataclasses.dataclass(frozen=True) class FooTy(dtypes.ExtendedDType): name: str = 'foo' diff --git a/tests/lax_test.py b/tests/lax_test.py index fdcf71b99..b69d623e7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3315,14 +3315,6 @@ class FooTyRules: def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((2,), jnp.dtype('uint32')) - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -3341,14 +3333,6 @@ class FooTyRules: return FooArray(aval.shape, buf) return handler - @staticmethod - def replicate_trailing_dims(ctx, val, aval): - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass - class FooTy(dtypes.ExtendedDType): type = dtypes.extended
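As a usage illustration (not part of the patch), the sketch below shows the
invariant this change centralizes in sharding_impls: a key array reports its
logical sharding, while the trailing key dimension of its uint32 base array is
always replicated. The mesh axis name 'x' and the device count are assumptions
for the example, and the exact printed representation depends on the backend.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever local devices are available (assumed >= 1).
mesh = Mesh(jax.devices(), ('x',))

# A batch of 8 typed PRNG keys: logical shape (8,), extended (key) dtype.
# The physical base array carries an extra trailing key dimension, e.g. (8, 2).
keys = jax.random.split(jax.random.key(0), 8)
keys = jax.device_put(keys, NamedSharding(mesh, P('x')))

# Logical sharding over the (8,) key shape, via sharding_impls.logical_sharding.
print(keys.sharding)
# The underlying uint32 data is sharded the same way on the leading dim and
# replicated on the trailing key dimension.
print(jax.random.key_data(keys).sharding)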