diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9f937d7e..9a9beb412 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ Remember to align the itemized text with the first line of an item within a list
   * JAX now requires ml_dtypes version 0.4.0 or newer.
 
 * Deprecations
+  * `jax.sharding.XLACompatibleSharding` is deprecated. Please use
+    `jax.sharding.Sharding`.
   * Removed a number of previously-deprecated APIs:
     * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
     * from {mod}`jax.lax`: `tie_in`
diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst
index 7b1393d8e..954f62b8a 100644
--- a/docs/jax.sharding.rst
+++ b/docs/jax.sharding.rst
@@ -10,9 +10,6 @@ Classes
 .. autoclass:: Sharding
   :members:
 
-.. autoclass:: XLACompatibleSharding
-  :members:
-  :show-inheritance:
 .. autoclass:: SingleDeviceSharding
   :members:
   :show-inheritance:
diff --git a/jax/BUILD b/jax/BUILD
index 9c0e787f9..b88abb9a3 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -569,6 +569,7 @@ pytype_strict_library(
         ":partial_eval",
         ":path",
         ":pickle_util",
+        ":sharding",
         ":sharding_impls",
         ":source_info_util",
         ":state_types",
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 9ce2fb224..41c230c7f 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -67,8 +67,7 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
-                                     XLACompatibleSharding)
+from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2428,7 +2427,7 @@ def _check_sharding(aval, s):
   if isinstance(s, Sharding):
     if isinstance(aval, core.AbstractToken):
       aval = core.token_shaped_array
-    if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
+    if not isinstance(s, PmapSharding):
       pjit.pjit_check_aval_sharding(
           (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
     s.shard_shape(aval.shape)  # should raise an Error if incompatible
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 2a265f582..854975c97 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    PmapSharding, SingleDeviceSharding, XLACompatibleSharding,
+    PmapSharding, SingleDeviceSharding,
     device_replica_id_map, hashed_index, num_addressable_indices)  # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
@@ -218,9 +218,8 @@ class ArrayImpl(basearray.Array):
             f"{self.aval.str_short()} with sharding {self.sharding}")
 
     # Rearrange arrays based on the device assignment.
-    if isinstance(self.sharding, XLACompatibleSharding):
-      addressable_da = self.sharding._addressable_device_assignment
-      self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
+    addressable_da = self.sharding._addressable_device_assignment
+    self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
 
   @property
   def shape(self) -> Shape:
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index a59126a65..2836fd9b2 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -48,7 +48,7 @@ from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
+    SingleDeviceSharding, NamedSharding,
     GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout
@@ -220,7 +220,7 @@ class SourceInfo(NamedTuple):
 
 def jaxpr_shardings(
     jaxpr: core.Jaxpr,
-) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
+) -> Iterator[tuple[Sharding, SourceInfo]]:
   from jax._src import pjit
   from jax.experimental import shard_map
 
@@ -241,7 +241,7 @@ def jaxpr_shardings(
                       for names in [*eqn.params['in_names'], *eqn.params['out_names']])
     elif eqn.primitive is device_put_p:
       s = eqn.params['device']
-      if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
+      if isinstance(s, Sharding) and s.memory_kind is not None:
         source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
         yield (s, source_info)
   for subjaxpr in core.subjaxprs(jaxpr):
@@ -392,7 +392,7 @@ def _device_put_sharding_impl(x, aval, device):
         isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
       # This has to be XLACompatible because _mcjax_reshard will run a
       # XLA computation.
-      assert isinstance(s, XLACompatibleSharding)
+      assert isinstance(s, Sharding)
       return _mcjax_reshard(x, s)
     if not s.is_fully_addressable:
       # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
@@ -467,11 +467,11 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)
 
 def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
     out_aval, = ctx.avals_out
-    if isinstance(device, XLACompatibleSharding):
+    if isinstance(device, Sharding):
       x = mlir.wrap_with_sharding_op(
           ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
@@ -484,7 +484,7 @@ mlir.register_lowering(
 
 def _common_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     raise NotImplementedError(
         "Passing memory_kind to device_put via Shardings is not supported on"
@@ -493,7 +493,7 @@ def _common_device_put_lowering(ctx, x, *, device, src):
 mlir.register_lowering(device_put_p, _common_device_put_lowering)
 
 def _propagate_mem_kind_dp(xm, device=None, src=None):
-  if isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)):
+  if isinstance(device, (Sharding, TransferToMemoryKind)):
     return device.memory_kind
   return None
 pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 62bb951f2..f66e4bd6a 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -47,6 +47,7 @@ from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
+from jax._src.sharding import Sharding as JSharding
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects
@@ -54,7 +55,6 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir import register_jax_dialects
-from jax._src.sharding_impls import XLACompatibleSharding
 from jax._src.state.types import AbstractRef
 
 map, unsafe_map = util.safe_map, map
@@ -735,7 +735,7 @@ def flatten_lowering_ir_args(
 _module_name_regex = re.compile(r"[^\w.-]")
 
 def sharded_aval(aval: core.AbstractValue,
-                 sharding: XLACompatibleSharding | None) -> core.AbstractValue:
+                 sharding: JSharding | None) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval
@@ -809,11 +809,11 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 
 def _to_physical_op_sharding(
-    aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
+    aval: core.AbstractValue, sharding: JSharding | None,
 ) -> xc.OpSharding | None:
   if sharding is None:
     return None
-  assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+  assert isinstance(sharding, JSharding)
   if isinstance(aval, AbstractRef):
     return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
@@ -831,10 +831,10 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
   return layout._to_xla_layout()
 
-def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
+def _get_mem_kind(s: JSharding | None) -> str | None:
   if s is None:
     return None
-  assert isinstance(s, sharding_impls.XLACompatibleSharding)
+  assert isinstance(s, JSharding)
   return s.memory_kind
 
@@ -849,8 +849,8 @@ def lower_jaxpr_to_module(
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | None] | None = None,
+    result_shardings: Sequence[JSharding | None] | None = None,
     in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
@@ -1090,8 +1090,8 @@ def lower_jaxpr_to_fun(
     *,
     public: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | None] | None = None,
+    result_shardings: Sequence[JSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index f1a46a386..04cdf8e5d 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -66,6 +66,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
+from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
@@ -114,9 +115,8 @@ def shard_arg(arg, sharding, canonicalize=True):
 
 @profiler.annotate_function
-def shard_args(
-    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
-) -> Sequence[jax.Array]:
+def shard_args(shardings: Sequence[JSharding], args
+               ) -> Sequence[jax.Array]:
   return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]
 
 shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
@@ -155,7 +155,7 @@ def _shard_mutable_array(x, sharding):
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array
 
 def batched_device_put(aval: core.ShapedArray,
-                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
+                       sharding: JSharding, xs: Sequence[Any],
                        devices: Sequence[jax.Device], committed: bool = True):
   from jax._src import array
@@ -191,7 +191,7 @@ _shard_aval_handlers[ShapedArray] = _shard_abstract_array
 
 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: sharding_impls.XLACompatibleSharding,
+    sharding: JSharding,
     indices: tuple[Index, ...] | None,
 ) -> Callable[[list[xc.ArrayImpl]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -864,9 +864,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[core.AbstractValue]
-  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  input_shardings: Sequence[JSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  output_shardings: Sequence[JSharding]
   unordered_effects: list[core.Effect]
   ordered_effects: list[core.Effect]
   keepalive: Sequence[Any]
@@ -1096,7 +1096,7 @@ class ResultsHandler:
 
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[JSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1108,7 +1108,7 @@ def local_avals_to_results_handler(
 
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    shardings: Sequence[JSharding],
     committed: bool) -> ResultsHandler:
   handlers = [
       global_aval_to_result_handler(global_aval, s, committed)
@@ -1617,8 +1617,7 @@ TilingMethod = Union[TileVectorize, TileManual]
 
 def check_if_any_auto(
-    shardings: Iterable[(sharding_impls.XLACompatibleSharding |
-                         AUTO | UnspecifiedValue)]) -> bool:
+    shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
   for s in shardings:
     if is_auto(s):
       return True
@@ -1683,7 +1682,7 @@ class DeviceAssignmentMismatchError(Exception):
 
 ShardingInfo = tuple[
-    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
+    Union[JSharding, UnspecifiedValue, AUTO],
     MismatchType,
     Union[Any, None],  # Any is dispatch.SourceInfo to avoid circular imports
 ]
@@ -1740,7 +1739,7 @@ def _get_and_check_device_assignment(
   final_device_assignment = first_sharding_info[0]
   return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
 
-MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
+MaybeSharding = Union[JSharding, UnspecifiedValue]
 
 def prune_unused_inputs(
@@ -1928,8 +1927,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   nreps = dispatch.jaxpr_replicas(jaxpr)
   _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)
 
-  in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
-  out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
+  in_mlir_shardings: list[JSharding | None] | None
+  out_mlir_shardings: list[JSharding | None] | None
   axis_ctx: mlir.AxisContext
 
   if nreps == 1:
@@ -2068,8 +2067,7 @@ class AllArgsInfo(NamedTuple):
 
 @lru_cache(maxsize=2048)
-def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding,
-                      ndim: int) -> GSPMDSharding:
+def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
   return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
@@ -2242,11 +2240,11 @@ def lower_sharding_computation(
 
 def _to_logical_sharding(
     aval: core.AbstractValue, sharding: MaybeSharding | AUTO
-) -> sharding_impls.XLACompatibleSharding | None:
+) -> JSharding | None:
   if is_unspecified(sharding) or is_auto(sharding):
     return None
   elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
-    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+    assert isinstance(sharding, JSharding)
     return sharding
   elif isinstance(aval, core.AbstractToken):
     return None
@@ -2339,8 +2337,8 @@ def lower_mesh_computation(
   # 2. Build up the HLO
   tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
 
-  in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
-  out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
+  in_partitions: list[JSharding | None] | None
+  out_partitions: list[JSharding | None] | None
   axis_ctx: mlir.AxisContext
   if spmd_lowering:
     in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
@@ -2554,7 +2552,7 @@ def _get_mesh_pspec_shardings_from_executable(
 
 _orig_out_sharding_handlers = {}
 
-_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
+_ShardingT = TypeVar("_ShardingT", bound=JSharding)
 
 def _register_out_sharding_handler(
@@ -2849,9 +2847,9 @@ class UnloadedMeshExecutable:
   device_assignment: xc.DeviceList | Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  input_shardings: Sequence[JSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  output_shardings: Sequence[JSharding]
   committed: bool
   name: str
   unordered_effects: list[core.Effect]
@@ -2891,9 +2889,8 @@ class UnloadedMeshExecutable:
       hlo: ir.Module,
       global_in_avals: Sequence[ShapedArray],
       global_out_avals: Sequence[ShapedArray],
-      in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
-      out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
-                               UnspecifiedValue)],
+      in_shardings: Sequence[JSharding | AUTO],
+      out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)],
       spmd_lowering: bool,
       tuple_args: bool,
       auto_spmd_lowering: bool,
@@ -3000,8 +2997,8 @@ class UnloadedMeshExecutable:
 class MeshExecutableFastpathData(NamedTuple):
   xla_executable: xc.LoadedExecutable
   out_pytree_def: Any
-  in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
-  out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  in_shardings: Sequence[JSharding]
+  out_shardings: Sequence[JSharding]
   out_avals: Sequence[ShapedArray]
   out_committed: Sequence[bool]
   kept_var_bitvec: Iterable[bool]
@@ -3077,10 +3074,10 @@ class MeshExecutable(stages.XlaExecutable):
                           self._kept_var_idx)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
-  def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[JSharding]:
     return self._in_shardings
 
-  def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[JSharding]:
     return self._out_shardings
 
   def input_layouts(self):
@@ -3190,7 +3187,7 @@ def check_device_backend_on_shardings(shardings) -> bool:
 
 def check_array_xla_sharding_layout_match(
     args_after_dce,
-    in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    in_xla_shardings: Sequence[JSharding],
     in_xla_layouts: Sequence[DeviceLocalLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None,
     kept_var_idx: set[int]) -> None:
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index ad4a18757..ae88e75d5 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -64,6 +64,7 @@ from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_client as xc
+from jax._src import sharding
 from jax._src.sharding_impls import (
     NamedSharding, XLACompatibleSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
@@ -784,7 +785,7 @@ def pjit(
   The valid resource assignment specifications are:
 
-  - :py:class:`XLACompatibleSharding`, which will decide how the value
+  - :py:class:`Sharding`, which will decide how the value
     will be partitioned. With this, using a mesh context manager is not
     required.
   - :py:obj:`None` is a special case whose semantics are:
@@ -906,10 +907,10 @@ def hashable_pytree(pytree):
 def _create_sharding_for_array(mesh, x, name, api_name):
   if x is None and (mesh is None or mesh.empty):
     return UNSPECIFIED
-  if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
+  if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
     return x
   if mesh is None:
-    msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to'
+    msg = ('jax.jit only supports `Sharding`s being passed to'
           f' {name}. Looks like you are passing either `PartitionSpec` or `None`'
           f' which is not allowed in jax.jit.\n')
     if name == 'in_shardings':
@@ -925,7 +926,7 @@ def _create_sharding_for_array(mesh, x, name, api_name):
     raise RuntimeError(
         f'{api_name} requires a non-empty mesh if you are passing'
         f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call'
-        f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
+        f' site? Alternatively, provide `Sharding`s to {name} and'
        ' then the mesh context manager is not required.')
   # A nice user error is raised in prepare_axis_resources.
   assert x is None or isinstance(x, ParsedPartitionSpec), x
@@ -1206,7 +1207,7 @@ def _check_and_canonicalize_out_shardings(
     out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
   if (is_unspecified(orig_out_shardings) or
-      isinstance(orig_out_shardings, XLACompatibleSharding)):
+      isinstance(orig_out_shardings, sharding.Sharding)):
     out_shardings_flat = (orig_out_shardings,) * len(out_type)
   else:
     out_shardings_flat = flatten_axis_resources(
@@ -1306,7 +1307,7 @@ def pjit_check_aval_sharding(
                        f'annotation {s}: {e}')
     # Use the `OpSharding` proto to find out how many ways each dimension of
     # the aval is sharded. This approach will work across all
-    # XLACompatibleSharding.
+    # Sharding.
     hlo_sharding = s._to_xla_hlo_sharding(len(shape))
     assert hlo_sharding is not None
     num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
@@ -1398,7 +1399,7 @@ def _resolve_in_shardings(
       if xla_extension_version < 270:
         if not isinstance(arg_s, XLACompatibleSharding):
           raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
-                           'which is not a subclass of XLACompatibleSharding.')
       # Don't consider PmapSharding inputs as committed. They will get resharded
       # unconditionally.
       if isinstance(arg_s, PmapSharding):
@@ -1901,7 +1902,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
 pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
 
 def _pjit_batcher_for_sharding(
-    s: XLACompatibleSharding | UnspecifiedValue,
+    s: sharding.Sharding | UnspecifiedValue,
     dim: int, val: tuple[str, ...], mesh, ndim: int):
   if is_unspecified(s):
     return s
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index b61694265..3824eb21e 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -158,7 +158,7 @@ def named_sharding_to_xla_hlo_sharding(
 
 @use_cpp_class(xc.NamedSharding)
-class NamedSharding(XLACompatibleSharding):
+class NamedSharding(sharding.Sharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.
 
   A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -303,7 +303,7 @@ def get_replicated_hlo_sharding():
 
 @use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(XLACompatibleSharding):
+class SingleDeviceSharding(sharding.Sharding):
   """A :class:`Sharding` that places its data on a single device.
 
   Args:
@@ -382,7 +382,7 @@ def pmap_sharding_devices_indices_map(
 
 @use_cpp_class(xc.PmapSharding)
-class PmapSharding(XLACompatibleSharding):
+class PmapSharding(sharding.Sharding):
   """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
@@ -583,7 +583,7 @@ def _positional_sharding_to_xla_hlo_sharding(
   return xc.HloSharding.from_proto(pbuf)
 
-class PositionalSharding(XLACompatibleSharding):
+class PositionalSharding(sharding.Sharding):
   _devices: tuple[xc.Device, ...]
   _memory_kind: str | None
   _ids: np.ndarray  # dtype DeviceIdSet
@@ -690,7 +690,7 @@ class PositionalSharding(XLACompatibleSharding):
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim
 
-  # XLACompatibleSharding interface
+  # sharding.Sharding interface
 
   @property
   def _device_assignment(self) -> XLADeviceAssignment:
@@ -734,7 +734,7 @@ class DeviceIdSet:
 
 @use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(XLACompatibleSharding):
+class GSPMDSharding(sharding.Sharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
@@ -1057,7 +1057,7 @@ def prepare_axis_resources(axis_resources,
       if xla_extension_version < 270:
         if not isinstance(entry, XLACompatibleSharding):
           raise ValueError(f'One of {what} got sharding {entry} which is not a '
-                           'subclass of XLACompatibleSharding.')
       new_entries.append(entry)
     else:
       new_entries.append(ParsedPartitionSpec.from_user_input(
@@ -1071,7 +1071,7 @@ def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
     if (is_unspecified_or_auto(arg_axis_resources) or
-        isinstance(arg_axis_resources, XLACompatibleSharding)):
+        isinstance(arg_axis_resources, sharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
@@ -1385,7 +1385,7 @@ def make_key_array_phys_sharding(aval, sharding):
 
 def physical_sharding(
-    aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
+    aval, sharding: sharding.Sharding) -> sharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)
 
@@ -1402,7 +1402,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
   return GSPMDSharding(phys_sharding._device_assignment,
                        xc.HloSharding.from_proto(logical_op_sharding))
 
-def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
+def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
   if isinstance(sharding, PmapSharding):
     return
   phys_aval = core.physical_aval(aval)
@@ -1415,7 +1415,7 @@ def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
         f" sharding: {sharding}, partitions: {partitions}, "
         f"num_trailing_dims: {num_trailing_dims}")
 
-def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
+def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
   # The trailing dims should always be replicated.
   check_replicated_trailing_dims(phys_sharding, aval)
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 7a0a55913..7d568d781 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -72,7 +72,7 @@ class Executable(Protocol):
     # TODO(frostig): improve annotation (sequences of arrays/buffers)
     raise NotImplementedError
 
-  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
     """Flat sequence of input shardings.
 
     May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
     """
     raise NotImplementedError
 
-  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
     """Flat sequence of output shardings.
 
     May raise ``NotImplementedError`` if unavailable, e.g.
     May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
@@ -218,11 +218,11 @@ class XlaExecutable(Executable):
 
   def call(self, *args_flat) -> Sequence[Any]:
     raise NotImplementedError("must override")
 
-  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
     raise NotImplementedError(
         "compiled executable carries no input sharding information")
 
-  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
     raise NotImplementedError(
         "compiled executable carries no output sharding information")
@@ -503,7 +503,7 @@ class Compiled(Stage):
     return self._executable.runtime_executable()
 
   @property
-  def input_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
+  def input_shardings(self):  # PyTree[sharding.Sharding]
     shardings_flat = self._executable.input_shardings()
     # Some input shardings got DCE'd
     if self.in_tree.num_leaves > len(shardings_flat):
@@ -513,7 +513,7 @@ class Compiled(Stage):
     return tree_util.tree_unflatten(self.in_tree, shardings_flat)  # pytype: disable=attribute-error
 
   @property
-  def output_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
+  def output_shardings(self):  # PyTree[sharding.Sharding]
     shardings_flat = self._executable.output_shardings()
     return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error
diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py
index 9d947527b..d79b07d18 100644
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -70,7 +70,7 @@ logger = logging.getLogger(__name__)
 
 async def create_async_array_from_callback(
     global_shape: array.Shape,
-    inp_sharding: sharding_impls.XLACompatibleSharding,
+    inp_sharding: jax.sharding.Sharding,
     data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
 ):
   device_to_index_map = inp_sharding.devices_indices_map(global_shape)
@@ -310,7 +310,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
 
 async def async_deserialize(
-    user_in_sharding: sharding_impls.XLACompatibleSharding | Layout,
+    user_in_sharding: jax.sharding.Sharding | Layout,
     tensorstore_spec: ts.Spec | dict[str, Any],
     global_shape: Sequence[int] | None = None,
     dtype=None,
@@ -320,10 +320,10 @@ async def async_deserialize(
 ):
   in_sharding = (user_in_sharding.sharding
                  if isinstance(user_in_sharding, Layout) else user_in_sharding)
-  if not isinstance(in_sharding, sharding_impls.XLACompatibleSharding):
+  if not isinstance(in_sharding, jax.sharding.Sharding):
     raise ValueError(
         'sharding passed to deserialization should be specified, concrete and'
-        f' an instance of `jax.XLACompatibleSharding`. Got {in_sharding}')
+        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
   dll = (user_in_sharding.device_local_layout
          if isinstance(user_in_sharding, Layout) else None)
   t = await ts.open(
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index ec2311736..38e0a0df3 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -133,10 +133,8 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape,
 
 def _to_hlo_sharding(sharding, num_dimensions):
-  if not isinstance(sharding, jax.sharding.XLACompatibleSharding):
-    raise ValueError(
-        "Custom Partitioning rules must return XLACompatibleShardings."
-    )
+  if not isinstance(sharding, jax.sharding.Sharding):
+    raise ValueError("Custom Partitioning rules must return Sharding.")
   return sharding._to_xla_hlo_sharding(num_dimensions)
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index 1c7a7dc9c..9223c566b 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -60,7 +60,7 @@ zip = util.safe_zip
 DType = Any
 Shape = jax._src.core.Shape
 # The values of input and output sharding from the lowering.
-LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
+LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
 HloSharding = xla_client.HloSharding
 
 # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
@@ -315,12 +315,12 @@ class Exported:
 
   def xla_compatible_in_shardings(
       self,
-      mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
-    """Creates XLACompatibleShardings corresponding to self.in_shardings.
+      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
+    """Creates Shardings corresponding to self.in_shardings.
 
     The Exported object stores `in_shardings` as HloShardings, which are
     independent of a mesh or set of devices. This method constructs
-    XLACompatibleSharding that can be used in JAX APIs such as `jax.jit` or
+    Sharding that can be used in JAX APIs such as `jax.jit` or
     `jax.device_put`.
 
     Example usage:
@@ -354,8 +354,8 @@ class Exported:
 
   def xla_compatible_out_shardings(
      self,
-      mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
-    """Creates XLACompatibleShardings corresponding to self.out_shardings.
+      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
+    """Creates Shardings corresponding to self.out_shardings.
 
     See documentation for xla_compatible_in_shardings.
""" @@ -978,7 +978,7 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], def _hlo_sharding_to_xla_compatible_sharding( hlo_sharding: HloSharding | None, - mesh: sharding.Mesh) -> sharding.XLACompatibleSharding | None: + mesh: sharding.Mesh) -> sharding.Sharding | None: if hlo_sharding is None: return None return sharding_impls._gspmd_to_named_sharding_via_mesh( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index e52b8778c..b86bbf0dc 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3463,7 +3463,7 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( - s: sharding.XLACompatibleSharding, + s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: if sharding_impls.is_unspecified(s): return None @@ -3515,8 +3515,8 @@ def _shard_value(val: TfVal, def _pjit(*args: TfVal, jaxpr: core.ClosedJaxpr, - in_shardings: Sequence[sharding.XLACompatibleSharding], - out_shardings: Sequence[sharding.XLACompatibleSharding], + in_shardings: Sequence[sharding.Sharding], + out_shardings: Sequence[sharding.Sharding], in_layouts, out_layouts, resource_env: mesh.ResourceEnv, donated_invars, @@ -3549,7 +3549,7 @@ tf_impl_with_avals[pjit.pjit_p] = _pjit def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.XLACompatibleSharding, + sharding: sharding.Sharding, resource_env: mesh.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, diff --git a/jax/sharding.py b/jax/sharding.py index 18caa9eb0..fe221f90a 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -17,7 +17,7 @@ from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( - XLACompatibleSharding as XLACompatibleSharding, + XLACompatibleSharding as _deprecated_XLACompatibleSharding, NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, @@ -28,3 +28,23 @@ from jax._src.partition_spec import ( PartitionSpec as PartitionSpec, ) from jax._src.interpreters.pxla import Mesh as Mesh + +_deprecations = { + # Added Jun 4, 2024. + "XLACompatibleSharding": ( + ( + "jax.sharding.XLACompatibleSharding is deprecated. Use" + " jax.sharding.Sharding instead." + ), + _deprecated_XLACompatibleSharding, + ) +} + +import typing +if typing.TYPE_CHECKING: + XLACompatibleSharding = _deprecated_XLACompatibleSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9c74a21ab..9c88ee623 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3157,7 +3157,7 @@ class ArrayPjitTest(jtu.JaxTestCase): mesh = jtu.create_global_mesh((1,), ('x',)) with self.assertRaisesRegex( RuntimeError, - "jax.jit only supports `XLACompatibleSharding`s being passed to " + "jax.jit only supports `Sharding`s being passed to " "in_shardings"): with mesh: jax.jit(lambda x: x, in_shardings=P('x'),