Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.

PiperOrigin-RevId: 640544939
Yash Katariya authored on 2024-06-05 09:06:36 -07:00; committed by jax authors
parent fc4d343c83
commit 1edd649de4
17 changed files with 120 additions and 106 deletions

View File

@ -12,6 +12,8 @@ Remember to align the itemized text with the first line of an item within a list
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
`jax.sharding.Sharding`.
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
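
For most callers the deprecation entry above amounts to a rename: code that annotated or type-checked against `jax.sharding.XLACompatibleSharding` can use `jax.sharding.Sharding` directly, since every concrete sharding class already subclasses it. The snippet below is a minimal migration sketch (not part of this commit), assuming a JAX release that includes this change; `describe_sharding` is a hypothetical helper used only for illustration.

import jax
import jax.numpy as jnp
from jax.sharding import Sharding, SingleDeviceSharding

def describe_sharding(x: jax.Array) -> str:
  # NamedSharding, PositionalSharding, PmapSharding, SingleDeviceSharding and
  # GSPMDSharding are all subclasses of Sharding, so one isinstance check
  # replaces the old XLACompatibleSharding check.
  s = x.sharding
  assert isinstance(s, Sharding)
  kind = "single-device" if isinstance(s, SingleDeviceSharding) else "multi-device"
  return f"{kind}, spanning {len(s.device_set)} device(s)"

print(describe_sharding(jnp.arange(8.0)))  # e.g. "single-device, spanning 1 device(s)"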

View File

@ -10,9 +10,6 @@ Classes
.. autoclass:: Sharding
:members:
.. autoclass:: XLACompatibleSharding
:members:
:show-inheritance:
.. autoclass:: SingleDeviceSharding
:members:
:show-inheritance:

View File

@ -569,6 +569,7 @@ pytype_strict_library(
":partial_eval",
":path",
":pickle_util",
":sharding",
":sharding_impls",
":source_info_util",
":state_types",

View File

@ -67,8 +67,7 @@ from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
XLACompatibleSharding)
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
@ -2428,7 +2427,7 @@ def _check_sharding(aval, s):
if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
if not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
s.shard_shape(aval.shape) # should raise an Error if incompatible

View File

@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, XLACompatibleSharding,
PmapSharding, SingleDeviceSharding,
device_replica_id_map, hashed_index, num_addressable_indices) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
@ -218,9 +218,8 @@ class ArrayImpl(basearray.Array):
f"{self.aval.str_short()} with sharding {self.sharding}")
# Rearrange arrays based on the device assignment.
if isinstance(self.sharding, XLACompatibleSharding):
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
@property
def shape(self) -> Shape:

View File

@ -48,7 +48,7 @@ from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
SingleDeviceSharding, NamedSharding,
GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout
@ -220,7 +220,7 @@ class SourceInfo(NamedTuple):
def jaxpr_shardings(
jaxpr: core.Jaxpr,
) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
) -> Iterator[tuple[Sharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map
@ -241,7 +241,7 @@ def jaxpr_shardings(
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
elif eqn.primitive is device_put_p:
s = eqn.params['device']
if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
if isinstance(s, Sharding) and s.memory_kind is not None:
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield (s, source_info)
for subjaxpr in core.subjaxprs(jaxpr):
@ -392,7 +392,7 @@ def _device_put_sharding_impl(x, aval, device):
isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
# This has to be XLACompatible because _mcjax_reshard will run a
# XLA computation.
assert isinstance(s, XLACompatibleSharding)
assert isinstance(s, Sharding)
return _mcjax_reshard(x, s)
if not s.is_fully_addressable:
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
@ -467,11 +467,11 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
batching.defvectorized(device_put_p)
def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
if isinstance(device, XLACompatibleSharding):
if isinstance(device, Sharding):
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
@ -484,7 +484,7 @@ mlir.register_lowering(
def _common_device_put_lowering(ctx, x, *, device, src):
if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
raise NotImplementedError(
"Passing memory_kind to device_put via Shardings is not supported on"
@ -493,7 +493,7 @@ def _common_device_put_lowering(ctx, x, *, device, src):
mlir.register_lowering(device_put_p, _common_device_put_lowering)
def _propagate_mem_kind_dp(xm, device=None, src=None):
if isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)):
if isinstance(device, (Sharding, TransferToMemoryKind)):
return device.memory_kind
return None
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp

View File

@ -47,6 +47,7 @@ from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects
@ -54,7 +55,6 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import register_jax_dialects
from jax._src.sharding_impls import XLACompatibleSharding
from jax._src.state.types import AbstractRef
map, unsafe_map = util.safe_map, map
@ -735,7 +735,7 @@ def flatten_lowering_ir_args(
_module_name_regex = re.compile(r"[^\w.-]")
def sharded_aval(aval: core.AbstractValue,
sharding: XLACompatibleSharding | None) -> core.AbstractValue:
sharding: JSharding | None) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval
@ -809,11 +809,11 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def _to_physical_op_sharding(
aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
aval: core.AbstractValue, sharding: JSharding | None,
) -> xc.OpSharding | None:
if sharding is None:
return None
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
assert isinstance(sharding, JSharding)
if isinstance(aval, AbstractRef):
return _to_physical_op_sharding(aval.inner_aval, sharding)
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
@ -831,10 +831,10 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
return layout._to_xla_layout()
def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
def _get_mem_kind(s: JSharding | None) -> str | None:
if s is None:
return None
assert isinstance(s, sharding_impls.XLACompatibleSharding)
assert isinstance(s, JSharding)
return s.memory_kind
@ -849,8 +849,8 @@ def lower_jaxpr_to_module(
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
arg_shardings: Sequence[JSharding | None] | None = None,
result_shardings: Sequence[JSharding | None] | None = None,
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
@ -1090,8 +1090,8 @@ def lower_jaxpr_to_fun(
*,
public: bool = False,
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
arg_shardings: Sequence[JSharding | None] | None = None,
result_shardings: Sequence[JSharding | None] | None = None,
use_sharding_annotations: bool = True,
input_output_aliases: Sequence[int | None] | None = None,
xla_donated_args: Sequence[bool] | None = None,

View File

@ -66,6 +66,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
@ -114,9 +115,8 @@ def shard_arg(arg, sharding, canonicalize=True):
@profiler.annotate_function
def shard_args(
shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
) -> Sequence[jax.Array]:
def shard_args(shardings: Sequence[JSharding], args
) -> Sequence[jax.Array]:
return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]
shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
@ -155,7 +155,7 @@ def _shard_mutable_array(x, sharding):
shard_arg_handlers[core.MutableArray] = _shard_mutable_array
def batched_device_put(aval: core.ShapedArray,
sharding: jax.sharding.Sharding, xs: Sequence[Any],
sharding: JSharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
from jax._src import array
@ -191,7 +191,7 @@ _shard_aval_handlers[ShapedArray] = _shard_abstract_array
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
sharding: JSharding,
indices: tuple[Index, ...] | None,
) -> Callable[[list[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
@ -864,9 +864,9 @@ class UnloadedPmapExecutable:
compiled: Any
backend: xb.XlaBackend
local_input_avals: Sequence[core.AbstractValue]
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
input_shardings: Sequence[JSharding]
local_output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
output_shardings: Sequence[JSharding]
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
@ -1096,7 +1096,7 @@ class ResultsHandler:
def local_avals_to_results_handler(
unmapped_local_out_avals: Sequence[ShapedArray],
local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
local_shardings: Sequence[JSharding]) -> ResultsHandler:
out_indices = [tuple(s.devices_indices_map(aval.shape).values())
for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
handlers = [
@ -1108,7 +1108,7 @@ def local_avals_to_results_handler(
def global_avals_to_results_handler(
global_out_avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
shardings: Sequence[JSharding],
committed: bool) -> ResultsHandler:
handlers = [
global_aval_to_result_handler(global_aval, s, committed)
@ -1617,8 +1617,7 @@ TilingMethod = Union[TileVectorize, TileManual]
def check_if_any_auto(
shardings: Iterable[(sharding_impls.XLACompatibleSharding |
AUTO | UnspecifiedValue)]) -> bool:
shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
for s in shardings:
if is_auto(s):
return True
@ -1683,7 +1682,7 @@ class DeviceAssignmentMismatchError(Exception):
ShardingInfo = tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
Union[JSharding, UnspecifiedValue, AUTO],
MismatchType,
Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports
]
@ -1740,7 +1739,7 @@ def _get_and_check_device_assignment(
final_device_assignment = first_sharding_info[0]
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
MaybeSharding = Union[JSharding, UnspecifiedValue]
def prune_unused_inputs(
@ -1928,8 +1927,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
nreps = dispatch.jaxpr_replicas(jaxpr)
_raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)
in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
in_mlir_shardings: list[JSharding | None] | None
out_mlir_shardings: list[JSharding | None] | None
axis_ctx: mlir.AxisContext
if nreps == 1:
@ -2068,8 +2067,7 @@ class AllArgsInfo(NamedTuple):
@lru_cache(maxsize=2048)
def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding,
ndim: int) -> GSPMDSharding:
def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding:
if isinstance(s, GSPMDSharding):
return s
return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
@ -2242,11 +2240,11 @@ def lower_sharding_computation(
def _to_logical_sharding(
aval: core.AbstractValue, sharding: MaybeSharding | AUTO
) -> sharding_impls.XLACompatibleSharding | None:
) -> JSharding | None:
if is_unspecified(sharding) or is_auto(sharding):
return None
elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
assert isinstance(sharding, JSharding)
return sharding
elif isinstance(aval, core.AbstractToken):
return None
@ -2339,8 +2337,8 @@ def lower_mesh_computation(
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
in_partitions: list[JSharding | None] | None
out_partitions: list[JSharding | None] | None
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
@ -2554,7 +2552,7 @@ def _get_mesh_pspec_shardings_from_executable(
_orig_out_sharding_handlers = {}
_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
_ShardingT = TypeVar("_ShardingT", bound=JSharding)
def _register_out_sharding_handler(
@ -2849,9 +2847,9 @@ class UnloadedMeshExecutable:
device_assignment: xc.DeviceList | Sequence[xc.Device]
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
input_shardings: Sequence[JSharding]
output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
output_shardings: Sequence[JSharding]
committed: bool
name: str
unordered_effects: list[core.Effect]
@ -2891,9 +2889,8 @@ class UnloadedMeshExecutable:
hlo: ir.Module,
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
UnspecifiedValue)],
in_shardings: Sequence[JSharding | AUTO],
out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)],
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
@ -3000,8 +2997,8 @@ class UnloadedMeshExecutable:
class MeshExecutableFastpathData(NamedTuple):
xla_executable: xc.LoadedExecutable
out_pytree_def: Any
in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
in_shardings: Sequence[JSharding]
out_shardings: Sequence[JSharding]
out_avals: Sequence[ShapedArray]
out_committed: Sequence[bool]
kept_var_bitvec: Iterable[bool]
@ -3077,10 +3074,10 @@ class MeshExecutable(stages.XlaExecutable):
self._kept_var_idx)
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
def input_shardings(self) -> Sequence[JSharding]:
return self._in_shardings
def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
def output_shardings(self) -> Sequence[JSharding]:
return self._out_shardings
def input_layouts(self):
@ -3190,7 +3187,7 @@ def check_device_backend_on_shardings(shardings) -> bool:
def check_array_xla_sharding_layout_match(
args_after_dce,
in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
in_xla_shardings: Sequence[JSharding],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None,
kept_var_idx: set[int]) -> None:

View File

@ -64,6 +64,7 @@ from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src import sharding
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
@ -784,7 +785,7 @@ def pjit(
The valid resource assignment specifications are:
- :py:class:`XLACompatibleSharding`, which will decide how the value
- :py:class:`Sharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None` is a special case whose semantics are:
@ -906,10 +907,10 @@ def hashable_pytree(pytree):
def _create_sharding_for_array(mesh, x, name, api_name):
if x is None and (mesh is None or mesh.empty):
return UNSPECIFIED
if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
return x
if mesh is None:
msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to'
msg = ('jax.jit only supports `Sharding`s being passed to'
f' {name}. Looks like you are passing either `PartitionSpec` or `None`'
f' which is not allowed in jax.jit.\n')
if name == 'in_shardings':
@ -925,7 +926,7 @@ def _create_sharding_for_array(mesh, x, name, api_name):
raise RuntimeError(
f'{api_name} requires a non-empty mesh if you are passing'
f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call'
f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
f' site? Alternatively, provide `Sharding`s to {name} and'
' then the mesh context manager is not required.')
# A nice user error is raised in prepare_axis_resources.
assert x is None or isinstance(x, ParsedPartitionSpec), x
@ -1206,7 +1207,7 @@ def _check_and_canonicalize_out_shardings(
out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if (is_unspecified(orig_out_shardings) or
isinstance(orig_out_shardings, XLACompatibleSharding)):
isinstance(orig_out_shardings, sharding.Sharding)):
out_shardings_flat = (orig_out_shardings,) * len(out_type)
else:
out_shardings_flat = flatten_axis_resources(
@ -1306,7 +1307,7 @@ def pjit_check_aval_sharding(
f'annotation {s}: {e}')
# Use the `OpSharding` proto to find out how many ways each dimension of
# the aval is sharded. This approach will work across all
# XLACompatibleSharding.
# Sharding.
hlo_sharding = s._to_xla_hlo_sharding(len(shape))
assert hlo_sharding is not None
num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
@ -1398,7 +1399,7 @@ def _resolve_in_shardings(
if xla_extension_version < 270:
if not isinstance(arg_s, XLACompatibleSharding):
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
'which is not a subclass of XLACompatibleSharding.')
'which is not a subclass of XLACompatibleSharding.')
# Don't consider PmapSharding inputs as committed. They will get resharded
# unconditionally.
if isinstance(arg_s, PmapSharding):
@ -1901,7 +1902,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
def _pjit_batcher_for_sharding(
s: XLACompatibleSharding | UnspecifiedValue,
s: sharding.Sharding | UnspecifiedValue,
dim: int, val: tuple[str, ...], mesh, ndim: int):
if is_unspecified(s):
return s

View File

@ -158,7 +158,7 @@ def named_sharding_to_xla_hlo_sharding(
@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
class NamedSharding(sharding.Sharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@ -303,7 +303,7 @@ def get_replicated_hlo_sharding():
@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(XLACompatibleSharding):
class SingleDeviceSharding(sharding.Sharding):
"""A :class:`Sharding` that places its data on a single device.
Args:
@ -382,7 +382,7 @@ def pmap_sharding_devices_indices_map(
@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
class PmapSharding(sharding.Sharding):
"""Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec
@ -583,7 +583,7 @@ def _positional_sharding_to_xla_hlo_sharding(
return xc.HloSharding.from_proto(pbuf)
class PositionalSharding(XLACompatibleSharding):
class PositionalSharding(sharding.Sharding):
_devices: tuple[xc.Device, ...]
_memory_kind: str | None
_ids: np.ndarray # dtype DeviceIdSet
@ -690,7 +690,7 @@ class PositionalSharding(XLACompatibleSharding):
def is_fully_replicated(self) -> bool:
return self.shape == (1,) * self.ndim
# XLACompatibleSharding interface
# sharding.Sharding interface
@property
def _device_assignment(self) -> XLADeviceAssignment:
@ -734,7 +734,7 @@ class DeviceIdSet:
@use_cpp_class(xc.GSPMDSharding)
class GSPMDSharding(XLACompatibleSharding):
class GSPMDSharding(sharding.Sharding):
_devices: tuple[Device, ...]
_hlo_sharding: xc.HloSharding
_memory_kind: str | None
@ -1057,7 +1057,7 @@ def prepare_axis_resources(axis_resources,
if xla_extension_version < 270:
if not isinstance(entry, XLACompatibleSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not a '
'subclass of XLACompatibleSharding.')
'subclass of XLACompatibleSharding.')
new_entries.append(entry)
else:
new_entries.append(ParsedPartitionSpec.from_user_input(
@ -1071,7 +1071,7 @@ def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if (is_unspecified_or_auto(arg_axis_resources) or
isinstance(arg_axis_resources, XLACompatibleSharding)):
isinstance(arg_axis_resources, sharding.Sharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = collections.Counter(
@ -1385,7 +1385,7 @@ def make_key_array_phys_sharding(aval, sharding):
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
aval, sharding: sharding.Sharding) -> sharding.Sharding:
return make_key_array_phys_sharding(aval, sharding)
@ -1402,7 +1402,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
@ -1415,7 +1415,7 @@ def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
# The trailing dims should always be replicated.
check_replicated_trailing_dims(phys_sharding, aval)

View File

@ -72,7 +72,7 @@ class Executable(Protocol):
# TODO(frostig): improve annotation (sequences of arrays/buffers)
raise NotImplementedError
def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
"""Flat sequence of input shardings.
May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
@ -80,7 +80,7 @@ class Executable(Protocol):
"""
raise NotImplementedError
def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
"""Flat sequence of output shardings.
May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
@ -218,11 +218,11 @@ class XlaExecutable(Executable):
def call(self, *args_flat) -> Sequence[Any]:
raise NotImplementedError("must override")
def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
raise NotImplementedError(
"compiled executable carries no input sharding information")
def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
raise NotImplementedError(
"compiled executable carries no output sharding information")
@ -503,7 +503,7 @@ class Compiled(Stage):
return self._executable.runtime_executable()
@property
def input_shardings(self): # PyTree[sharding.XLACompatibleSharding]
def input_shardings(self): # PyTree[sharding.Sharding]
shardings_flat = self._executable.input_shardings()
# Some input shardings got DCE'd
if self.in_tree.num_leaves > len(shardings_flat):
@ -513,7 +513,7 @@ class Compiled(Stage):
return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error
@property
def output_shardings(self): # PyTree[sharding.XLACompatibleSharding]
def output_shardings(self): # PyTree[sharding.Sharding]
shardings_flat = self._executable.output_shardings()
return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error
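
As a hedged usage sketch (not part of the diff): the `Compiled.input_shardings` and `output_shardings` properties touched above now advertise plain `jax.sharding.Sharding` objects. Assuming a recent JAX release, they can be inspected after lowering and compiling a jitted function:

import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x * 2).lower(jnp.arange(4.0)).compile()
# Pytrees of jax.sharding.Sharding (previously annotated as
# XLACompatibleSharding), matching the argument/output structure.
print(compiled.input_shardings)
print(compiled.output_shardings)
print(all(isinstance(s, jax.sharding.Sharding)
          for s in jax.tree_util.tree_leaves(compiled.input_shardings)))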

View File

@ -70,7 +70,7 @@ logger = logging.getLogger(__name__)
async def create_async_array_from_callback(
global_shape: array.Shape,
inp_sharding: sharding_impls.XLACompatibleSharding,
inp_sharding: jax.sharding.Sharding,
data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
):
device_to_index_map = inp_sharding.devices_indices_map(global_shape)
@ -310,7 +310,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
async def async_deserialize(
user_in_sharding: sharding_impls.XLACompatibleSharding | Layout,
user_in_sharding: jax.sharding.Sharding | Layout,
tensorstore_spec: ts.Spec | dict[str, Any],
global_shape: Sequence[int] | None = None,
dtype=None,
@ -320,10 +320,10 @@ async def async_deserialize(
):
in_sharding = (user_in_sharding.sharding
if isinstance(user_in_sharding, Layout) else user_in_sharding)
if not isinstance(in_sharding, sharding_impls.XLACompatibleSharding):
if not isinstance(in_sharding, jax.sharding.Sharding):
raise ValueError(
'sharding passed to deserialization should be specified, concrete and'
f' an instance of `jax.XLACompatibleSharding`. Got {in_sharding}')
f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
dll = (user_in_sharding.device_local_layout
if isinstance(user_in_sharding, Layout) else None)
t = await ts.open(

View File

@ -133,10 +133,8 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape,
def _to_hlo_sharding(sharding, num_dimensions):
if not isinstance(sharding, jax.sharding.XLACompatibleSharding):
raise ValueError(
"Custom Partitioning rules must return XLACompatibleShardings."
)
if not isinstance(sharding, jax.sharding.Sharding):
raise ValueError("Custom Partitioning rules must return Sharding.")
return sharding._to_xla_hlo_sharding(num_dimensions)

View File

@ -60,7 +60,7 @@ zip = util.safe_zip
DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
@ -315,12 +315,12 @@ class Exported:
def xla_compatible_in_shardings(
self,
mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
"""Creates XLACompatibleShardings corresponding to self.in_shardings.
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.in_shardings.
The Exported object stores `in_shardings` as HloShardings, which are
independent of a mesh or set of devices. This method constructs
XLACompatibleSharding that can be used in JAX APIs such as `jax.jit` or
Sharding that can be used in JAX APIs such as `jax.jit` or
`jax.device_put`.
Example usage:
@ -354,8 +354,8 @@ class Exported:
def xla_compatible_out_shardings(
self,
mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
"""Creates XLACompatibleShardings corresponding to self.out_shardings.
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.out_shardings.
See documentation for xla_compatible_in_shardings.
"""
@ -978,7 +978,7 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
def _hlo_sharding_to_xla_compatible_sharding(
hlo_sharding: HloSharding | None,
mesh: sharding.Mesh) -> sharding.XLACompatibleSharding | None:
mesh: sharding.Mesh) -> sharding.Sharding | None:
if hlo_sharding is None:
return None
return sharding_impls._gspmd_to_named_sharding_via_mesh(

View File

@ -3463,7 +3463,7 @@ def split_to_logical_devices(tensor: TfVal,
def _xla_compatible_sharding_to_hlo_sharding(
s: sharding.XLACompatibleSharding,
s: sharding.Sharding,
aval: core.ShapedArray) -> xla_client.HloSharding | None:
if sharding_impls.is_unspecified(s):
return None
@ -3515,8 +3515,8 @@ def _shard_value(val: TfVal,
def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_shardings: Sequence[sharding.XLACompatibleSharding],
out_shardings: Sequence[sharding.XLACompatibleSharding],
in_shardings: Sequence[sharding.Sharding],
out_shardings: Sequence[sharding.Sharding],
in_layouts, out_layouts,
resource_env: mesh.ResourceEnv,
donated_invars,
@ -3549,7 +3549,7 @@ tf_impl_with_avals[pjit.pjit_p] = _pjit
def _pjit_sharding_constraint(arg: TfVal, *,
sharding: sharding.XLACompatibleSharding,
sharding: sharding.Sharding,
resource_env: mesh.ResourceEnv,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,

View File

@ -17,7 +17,7 @@
from jax._src.sharding import Sharding as Sharding
from jax._src.sharding_impls import (
XLACompatibleSharding as XLACompatibleSharding,
XLACompatibleSharding as _deprecated_XLACompatibleSharding,
NamedSharding as NamedSharding,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
@ -28,3 +28,23 @@ from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
from jax._src.interpreters.pxla import Mesh as Mesh
_deprecations = {
    # Added Jun 4, 2024.
    "XLACompatibleSharding": (
        (
            "jax.sharding.XLACompatibleSharding is deprecated. Use"
            " jax.sharding.Sharding instead."
        ),
        _deprecated_XLACompatibleSharding,
    )
}

import typing
if typing.TYPE_CHECKING:
  XLACompatibleSharding = _deprecated_XLACompatibleSharding
else:
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
del typing
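
For context, an illustrative sketch (not part of the commit): with the `_deprecations` table and module-level `__getattr__` registered above, the old attribute still resolves but emits a `DeprecationWarning`. Assuming JAX's standard deprecation machinery, downstream code observes it like this:

import warnings
import jax

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  legacy = jax.sharding.XLACompatibleSharding   # resolves via __getattr__, warns
  current = jax.sharding.Sharding               # new spelling, no warning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(legacy, current)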

View File

@ -3157,7 +3157,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((1,), ('x',))
with self.assertRaisesRegex(
RuntimeError,
"jax.jit only supports `XLACompatibleSharding`s being passed to "
"jax.jit only supports `Sharding`s being passed to "
"in_shardings"):
with mesh:
jax.jit(lambda x: x, in_shardings=P('x'),