Mirror of https://github.com/ROCm/jax.git
Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.

PiperOrigin-RevId: 640544939

Commit: 1edd649de4 (parent: fc4d343c83)
@@ -12,6 +12,8 @@ Remember to align the itemized text with the first line of an item within a list
 * JAX now requires ml_dtypes version 0.4.0 or newer.

 * Deprecations
+  * `jax.sharding.XLACompatibleSharding` is deprecated. Please use
+    `jax.sharding.Sharding`.
   * Removed a number of previously-deprecated APIs:
     * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
     * from {mod}`jax.lax`: `tie_in`
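Not part of the diff: a minimal migration sketch for downstream code affected by this deprecation (hypothetical user code, assuming a single-device setup):

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

x = jax.device_put(jnp.arange(8), SingleDeviceSharding(jax.devices()[0]))

# Previously: isinstance(x.sharding, jax.sharding.XLACompatibleSharding)
# (that attribute now emits a DeprecationWarning). Every concrete sharding
# already derives from jax.sharding.Sharding, so check against that instead.
assert isinstance(x.sharding, jax.sharding.Sharding)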
@@ -10,9 +10,6 @@ Classes

 .. autoclass:: Sharding
   :members:
-.. autoclass:: XLACompatibleSharding
-  :members:
-  :show-inheritance:
 .. autoclass:: SingleDeviceSharding
   :members:
   :show-inheritance:
@@ -569,6 +569,7 @@ pytype_strict_library(
         ":partial_eval",
         ":path",
         ":pickle_util",
+        ":sharding",
         ":sharding_impls",
         ":source_info_util",
         ":state_types",
@@ -67,8 +67,7 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
-                                     XLACompatibleSharding)
+from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2428,7 +2427,7 @@ def _check_sharding(aval, s):
   if isinstance(s, Sharding):
     if isinstance(aval, core.AbstractToken):
       aval = core.token_shaped_array
-    if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
+    if not isinstance(s, PmapSharding):
       pjit.pjit_check_aval_sharding(
           (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
     s.shard_shape(aval.shape) # should raise an Error if incompatible
@@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    PmapSharding, SingleDeviceSharding, XLACompatibleSharding,
+    PmapSharding, SingleDeviceSharding,
     device_replica_id_map, hashed_index, num_addressable_indices) # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
@@ -218,9 +218,8 @@ class ArrayImpl(basearray.Array):
           f"{self.aval.str_short()} with sharding {self.sharding}")

     # Rearrange arrays based on the device assignment.
-    if isinstance(self.sharding, XLACompatibleSharding):
-      addressable_da = self.sharding._addressable_device_assignment
-      self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
+    addressable_da = self.sharding._addressable_device_assignment
+    self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]

   @property
   def shape(self) -> Shape:
@@ -48,7 +48,7 @@ from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
+    SingleDeviceSharding, NamedSharding,
     GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout

@@ -220,7 +220,7 @@ class SourceInfo(NamedTuple):

 def jaxpr_shardings(
     jaxpr: core.Jaxpr,
-) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
+) -> Iterator[tuple[Sharding, SourceInfo]]:
   from jax._src import pjit
   from jax.experimental import shard_map

@@ -241,7 +241,7 @@ def jaxpr_shardings(
                     for names in [*eqn.params['in_names'], *eqn.params['out_names']])
     elif eqn.primitive is device_put_p:
       s = eqn.params['device']
-      if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
+      if isinstance(s, Sharding) and s.memory_kind is not None:
         source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
         yield (s, source_info)
   for subjaxpr in core.subjaxprs(jaxpr):
@@ -392,7 +392,7 @@ def _device_put_sharding_impl(x, aval, device):
         isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
       # This has to be XLACompatible because _mcjax_reshard will run a
       # XLA computation.
-      assert isinstance(s, XLACompatibleSharding)
+      assert isinstance(s, Sharding)
       return _mcjax_reshard(x, s)
     if not s.is_fully_addressable:
       # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
@@ -467,11 +467,11 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)

 def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
     out_aval, = ctx.avals_out
-    if isinstance(device, XLACompatibleSharding):
+    if isinstance(device, Sharding):
       x = mlir.wrap_with_sharding_op(
           ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
@@ -484,7 +484,7 @@ mlir.register_lowering(


 def _common_device_put_lowering(ctx, x, *, device, src):
-  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+  if (isinstance(device, (Sharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     raise NotImplementedError(
         "Passing memory_kind to device_put via Shardings is not supported on"
@@ -493,7 +493,7 @@ def _common_device_put_lowering(ctx, x, *, device, src):
 mlir.register_lowering(device_put_p, _common_device_put_lowering)

 def _propagate_mem_kind_dp(xm, device=None, src=None):
-  if isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)):
+  if isinstance(device, (Sharding, TransferToMemoryKind)):
     return device.memory_kind
   return None
 pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
@@ -47,6 +47,7 @@ from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
+from jax._src.sharding import Sharding as JSharding
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects
@@ -54,7 +55,6 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir import register_jax_dialects
-from jax._src.sharding_impls import XLACompatibleSharding
 from jax._src.state.types import AbstractRef

 map, unsafe_map = util.safe_map, map
@@ -735,7 +735,7 @@ def flatten_lowering_ir_args(
 _module_name_regex = re.compile(r"[^\w.-]")

 def sharded_aval(aval: core.AbstractValue,
-                 sharding: XLACompatibleSharding | None) -> core.AbstractValue:
+                 sharding: JSharding | None) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval
@@ -809,11 +809,11 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]


 def _to_physical_op_sharding(
-    aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
+    aval: core.AbstractValue, sharding: JSharding | None,
 ) -> xc.OpSharding | None:
   if sharding is None:
     return None
-  assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+  assert isinstance(sharding, JSharding)
   if isinstance(aval, AbstractRef):
     return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
@@ -831,10 +831,10 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
   return layout._to_xla_layout()


-def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
+def _get_mem_kind(s: JSharding | None) -> str | None:
   if s is None:
     return None
-  assert isinstance(s, sharding_impls.XLACompatibleSharding)
+  assert isinstance(s, JSharding)
   return s.memory_kind


@@ -849,8 +849,8 @@ def lower_jaxpr_to_module(
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | None] | None = None,
+    result_shardings: Sequence[JSharding | None] | None = None,
     in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
@@ -1090,8 +1090,8 @@ def lower_jaxpr_to_fun(
     *,
     public: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | None] | None = None,
+    result_shardings: Sequence[JSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
@@ -66,6 +66,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
+from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
@@ -114,9 +115,8 @@ def shard_arg(arg, sharding, canonicalize=True):


 @profiler.annotate_function
-def shard_args(
-    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
-) -> Sequence[jax.Array]:
+def shard_args(shardings: Sequence[JSharding], args
+               ) -> Sequence[jax.Array]:
   return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]

 shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
@@ -155,7 +155,7 @@ def _shard_mutable_array(x, sharding):
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array

 def batched_device_put(aval: core.ShapedArray,
-                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
+                       sharding: JSharding, xs: Sequence[Any],
                        devices: Sequence[jax.Device], committed: bool = True):
   from jax._src import array

@@ -191,7 +191,7 @@ _shard_aval_handlers[ShapedArray] = _shard_abstract_array

 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: sharding_impls.XLACompatibleSharding,
+    sharding: JSharding,
     indices: tuple[Index, ...] | None,
 ) -> Callable[[list[xc.ArrayImpl]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -864,9 +864,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[core.AbstractValue]
-  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  input_shardings: Sequence[JSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  output_shardings: Sequence[JSharding]
   unordered_effects: list[core.Effect]
   ordered_effects: list[core.Effect]
   keepalive: Sequence[Any]
@@ -1096,7 +1096,7 @@ class ResultsHandler:

 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[JSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1108,7 +1108,7 @@ def local_avals_to_results_handler(

 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    shardings: Sequence[JSharding],
     committed: bool) -> ResultsHandler:
   handlers = [
       global_aval_to_result_handler(global_aval, s, committed)
@@ -1617,8 +1617,7 @@ TilingMethod = Union[TileVectorize, TileManual]


 def check_if_any_auto(
-    shardings: Iterable[(sharding_impls.XLACompatibleSharding |
-                         AUTO | UnspecifiedValue)]) -> bool:
+    shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
   for s in shardings:
     if is_auto(s):
       return True
@@ -1683,7 +1682,7 @@ class DeviceAssignmentMismatchError(Exception):


 ShardingInfo = tuple[
-    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
+    Union[JSharding, UnspecifiedValue, AUTO],
     MismatchType,
     Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports
 ]
@@ -1740,7 +1739,7 @@ def _get_and_check_device_assignment(
   final_device_assignment = first_sharding_info[0]
   return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

-MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
+MaybeSharding = Union[JSharding, UnspecifiedValue]


 def prune_unused_inputs(
@@ -1928,8 +1927,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   nreps = dispatch.jaxpr_replicas(jaxpr)
   _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)

-  in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
-  out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
+  in_mlir_shardings: list[JSharding | None] | None
+  out_mlir_shardings: list[JSharding | None] | None
   axis_ctx: mlir.AxisContext

   if nreps == 1:
@@ -2068,8 +2067,7 @@ class AllArgsInfo(NamedTuple):


 @lru_cache(maxsize=2048)
-def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding,
-                      ndim: int) -> GSPMDSharding:
+def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
   return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
@@ -2242,11 +2240,11 @@ def lower_sharding_computation(

 def _to_logical_sharding(
     aval: core.AbstractValue, sharding: MaybeSharding | AUTO
-) -> sharding_impls.XLACompatibleSharding | None:
+) -> JSharding | None:
   if is_unspecified(sharding) or is_auto(sharding):
     return None
   elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
-    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+    assert isinstance(sharding, JSharding)
     return sharding
   elif isinstance(aval, core.AbstractToken):
     return None
@@ -2339,8 +2337,8 @@ def lower_mesh_computation(
   # 2. Build up the HLO
   tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)

-  in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
-  out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
+  in_partitions: list[JSharding | None] | None
+  out_partitions: list[JSharding | None] | None
   axis_ctx: mlir.AxisContext
   if spmd_lowering:
     in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
@@ -2554,7 +2552,7 @@ def _get_mesh_pspec_shardings_from_executable(

 _orig_out_sharding_handlers = {}

-_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
+_ShardingT = TypeVar("_ShardingT", bound=JSharding)


 def _register_out_sharding_handler(
@@ -2849,9 +2847,9 @@ class UnloadedMeshExecutable:
   device_assignment: xc.DeviceList | Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  input_shardings: Sequence[JSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  output_shardings: Sequence[JSharding]
   committed: bool
   name: str
   unordered_effects: list[core.Effect]
@@ -2891,9 +2889,8 @@ class UnloadedMeshExecutable:
                hlo: ir.Module,
                global_in_avals: Sequence[ShapedArray],
                global_out_avals: Sequence[ShapedArray],
-               in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
-               out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
-                                        UnspecifiedValue)],
+               in_shardings: Sequence[JSharding | AUTO],
+               out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)],
                spmd_lowering: bool,
                tuple_args: bool,
                auto_spmd_lowering: bool,
@@ -3000,8 +2997,8 @@ class UnloadedMeshExecutable:
 class MeshExecutableFastpathData(NamedTuple):
   xla_executable: xc.LoadedExecutable
   out_pytree_def: Any
-  in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
-  out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
+  in_shardings: Sequence[JSharding]
+  out_shardings: Sequence[JSharding]
   out_avals: Sequence[ShapedArray]
   out_committed: Sequence[bool]
   kept_var_bitvec: Iterable[bool]
@@ -3077,10 +3074,10 @@ class MeshExecutable(stages.XlaExecutable):
                                self._kept_var_idx)
     return self.unsafe_call(*args) # pylint: disable=not-callable

-  def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[JSharding]:
     return self._in_shardings

-  def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[JSharding]:
     return self._out_shardings

   def input_layouts(self):
@@ -3190,7 +3187,7 @@ def check_device_backend_on_shardings(shardings) -> bool:

 def check_array_xla_sharding_layout_match(
     args_after_dce,
-    in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    in_xla_shardings: Sequence[JSharding],
     in_xla_layouts: Sequence[DeviceLocalLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None,
     kept_var_idx: set[int]) -> None:
@@ -64,6 +64,7 @@ from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_client as xc
+from jax._src import sharding
 from jax._src.sharding_impls import (
     NamedSharding, XLACompatibleSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
@@ -784,7 +785,7 @@ def pjit(

   The valid resource assignment specifications are:

-  - :py:class:`XLACompatibleSharding`, which will decide how the value
+  - :py:class:`Sharding`, which will decide how the value
     will be partitioned. With this, using a mesh context manager is not
     required.
   - :py:obj:`None` is a special case whose semantics are:
@@ -906,10 +907,10 @@ def hashable_pytree(pytree):
 def _create_sharding_for_array(mesh, x, name, api_name):
   if x is None and (mesh is None or mesh.empty):
     return UNSPECIFIED
-  if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
+  if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
     return x
   if mesh is None:
-    msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to'
+    msg = ('jax.jit only supports `Sharding`s being passed to'
           f' {name}. Looks like you are passing either `PartitionSpec` or `None`'
           f' which is not allowed in jax.jit.\n')
     if name == 'in_shardings':
@@ -925,7 +926,7 @@ def _create_sharding_for_array(mesh, x, name, api_name):
     raise RuntimeError(
         f'{api_name} requires a non-empty mesh if you are passing'
         f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call'
-        f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
+        f' site? Alternatively, provide `Sharding`s to {name} and'
        ' then the mesh context manager is not required.')
   # A nice user error is raised in prepare_axis_resources.
   assert x is None or isinstance(x, ParsedPartitionSpec), x
@@ -1206,7 +1207,7 @@ def _check_and_canonicalize_out_shardings(
     out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
   if (is_unspecified(orig_out_shardings) or
-      isinstance(orig_out_shardings, XLACompatibleSharding)):
+      isinstance(orig_out_shardings, sharding.Sharding)):
     out_shardings_flat = (orig_out_shardings,) * len(out_type)
   else:
     out_shardings_flat = flatten_axis_resources(
@@ -1306,7 +1307,7 @@ def pjit_check_aval_sharding(
                        f'annotation {s}: {e}')
     # Use the `OpSharding` proto to find out how many ways each dimension of
     # the aval is sharded. This approach will work across all
-    # XLACompatibleSharding.
+    # Sharding.
     hlo_sharding = s._to_xla_hlo_sharding(len(shape))
     assert hlo_sharding is not None
     num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
@@ -1398,7 +1399,7 @@ def _resolve_in_shardings(
       if xla_extension_version < 270:
         if not isinstance(arg_s, XLACompatibleSharding):
           raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
-                           'which is not a subclass of XLACompatibleSharding.')
+                          'which is not a subclass of XLACompatibleSharding.')
       # Don't consider PmapSharding inputs as committed. They will get resharded
       # unconditionally.
       if isinstance(arg_s, PmapSharding):
@@ -1901,7 +1902,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
 pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)

 def _pjit_batcher_for_sharding(
-    s: XLACompatibleSharding | UnspecifiedValue,
+    s: sharding.Sharding | UnspecifiedValue,
     dim: int, val: tuple[str, ...], mesh, ndim: int):
   if is_unspecified(s):
     return s
@@ -158,7 +158,7 @@ def named_sharding_to_xla_hlo_sharding(


 @use_cpp_class(xc.NamedSharding)
-class NamedSharding(XLACompatibleSharding):
+class NamedSharding(sharding.Sharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.

   A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -303,7 +303,7 @@ def get_replicated_hlo_sharding():


 @use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(XLACompatibleSharding):
+class SingleDeviceSharding(sharding.Sharding):
   """A :class:`Sharding` that places its data on a single device.

   Args:
@@ -382,7 +382,7 @@ def pmap_sharding_devices_indices_map(


 @use_cpp_class(xc.PmapSharding)
-class PmapSharding(XLACompatibleSharding):
+class PmapSharding(sharding.Sharding):
   """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
@@ -583,7 +583,7 @@ def _positional_sharding_to_xla_hlo_sharding(
   return xc.HloSharding.from_proto(pbuf)


-class PositionalSharding(XLACompatibleSharding):
+class PositionalSharding(sharding.Sharding):
   _devices: tuple[xc.Device, ...]
   _memory_kind: str | None
   _ids: np.ndarray # dtype DeviceIdSet
@@ -690,7 +690,7 @@ class PositionalSharding(XLACompatibleSharding):
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim

-  # XLACompatibleSharding interface
+  # sharding.Sharding interface

   @property
   def _device_assignment(self) -> XLADeviceAssignment:
@@ -734,7 +734,7 @@ class DeviceIdSet:


 @use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(XLACompatibleSharding):
+class GSPMDSharding(sharding.Sharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
@@ -1057,7 +1057,7 @@ def prepare_axis_resources(axis_resources,
       if xla_extension_version < 270:
         if not isinstance(entry, XLACompatibleSharding):
           raise ValueError(f'One of {what} got sharding {entry} which is not a '
-                           'subclass of XLACompatibleSharding.')
+                          'subclass of XLACompatibleSharding.')
       new_entries.append(entry)
     else:
       new_entries.append(ParsedPartitionSpec.from_user_input(
@@ -1071,7 +1071,7 @@ def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
     if (is_unspecified_or_auto(arg_axis_resources) or
-        isinstance(arg_axis_resources, XLACompatibleSharding)):
+        isinstance(arg_axis_resources, sharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
@@ -1385,7 +1385,7 @@ def make_key_array_phys_sharding(aval, sharding):


 def physical_sharding(
-    aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
+    aval, sharding: sharding.Sharding) -> sharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)


@@ -1402,7 +1402,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
   return GSPMDSharding(phys_sharding._device_assignment,
                        xc.HloSharding.from_proto(logical_op_sharding))

-def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
+def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
   if isinstance(sharding, PmapSharding):
     return
   phys_aval = core.physical_aval(aval)
@@ -1415,7 +1415,7 @@ def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
         f" sharding: {sharding}, partitions: {partitions}, "
         f"num_trailing_dims: {num_trailing_dims}")

-def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
+def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
   # The trailing dims should always be replicated.
   check_replicated_trailing_dims(phys_sharding, aval)

@@ -72,7 +72,7 @@ class Executable(Protocol):
     # TODO(frostig): improve annotation (sequences of arrays/buffers)
     raise NotImplementedError

-  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
     """Flat sequence of input shardings.

     May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
@@ -80,7 +80,7 @@ class Executable(Protocol):
     """
     raise NotImplementedError

-  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
     """Flat sequence of output shardings.

     May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
@@ -218,11 +218,11 @@ class XlaExecutable(Executable):
   def call(self, *args_flat) -> Sequence[Any]:
     raise NotImplementedError("must override")

-  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def input_shardings(self) -> Sequence[jax.sharding.Sharding]:
     raise NotImplementedError(
         "compiled executable carries no input sharding information")

-  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
+  def output_shardings(self) -> Sequence[jax.sharding.Sharding]:
     raise NotImplementedError(
         "compiled executable carries no output sharding information")

@@ -503,7 +503,7 @@ class Compiled(Stage):
     return self._executable.runtime_executable()

   @property
-  def input_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
+  def input_shardings(self):  # PyTree[sharding.Sharding]
     shardings_flat = self._executable.input_shardings()
     # Some input shardings got DCE'd
     if self.in_tree.num_leaves > len(shardings_flat):
@@ -513,7 +513,7 @@ class Compiled(Stage):
     return tree_util.tree_unflatten(self.in_tree, shardings_flat)  # pytype: disable=attribute-error

   @property
-  def output_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
+  def output_shardings(self):  # PyTree[sharding.Sharding]
     shardings_flat = self._executable.output_shardings()
     return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error

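For context on the `input_shardings` and `output_shardings` properties touched above, a small inspection sketch (hypothetical usage, not part of the diff); both now advertise plain `jax.sharding.Sharding` objects:

import jax
import jax.numpy as jnp

compiled = jax.jit(lambda a: a * 2).lower(jnp.arange(4)).compile()

# The compiled executable reports its shardings as jax.sharding.Sharding
# instances (on a single-device run these are SingleDeviceSharding objects).
print(compiled.input_shardings)
assert all(isinstance(s, jax.sharding.Sharding)
           for s in jax.tree_util.tree_leaves(compiled.input_shardings))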
@@ -70,7 +70,7 @@ logger = logging.getLogger(__name__)

 async def create_async_array_from_callback(
     global_shape: array.Shape,
-    inp_sharding: sharding_impls.XLACompatibleSharding,
+    inp_sharding: jax.sharding.Sharding,
     data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
 ):
   device_to_index_map = inp_sharding.devices_indices_map(global_shape)
@@ -310,7 +310,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,


 async def async_deserialize(
-    user_in_sharding: sharding_impls.XLACompatibleSharding | Layout,
+    user_in_sharding: jax.sharding.Sharding | Layout,
     tensorstore_spec: ts.Spec | dict[str, Any],
     global_shape: Sequence[int] | None = None,
     dtype=None,
@@ -320,10 +320,10 @@ async def async_deserialize(
 ):
   in_sharding = (user_in_sharding.sharding
                  if isinstance(user_in_sharding, Layout) else user_in_sharding)
-  if not isinstance(in_sharding, sharding_impls.XLACompatibleSharding):
+  if not isinstance(in_sharding, jax.sharding.Sharding):
     raise ValueError(
         'sharding passed to deserialization should be specified, concrete and'
-        f' an instance of `jax.XLACompatibleSharding`. Got {in_sharding}')
+        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
   dll = (user_in_sharding.device_local_layout
          if isinstance(user_in_sharding, Layout) else None)
   t = await ts.open(
@@ -133,10 +133,8 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape,


 def _to_hlo_sharding(sharding, num_dimensions):
-  if not isinstance(sharding, jax.sharding.XLACompatibleSharding):
-    raise ValueError(
-        "Custom Partitioning rules must return XLACompatibleShardings."
-    )
+  if not isinstance(sharding, jax.sharding.Sharding):
+    raise ValueError("Custom Partitioning rules must return Sharding.")
   return sharding._to_xla_hlo_sharding(num_dimensions)


@@ -60,7 +60,7 @@ zip = util.safe_zip
 DType = Any
 Shape = jax._src.core.Shape
 # The values of input and output sharding from the lowering.
-LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
+LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
 HloSharding = xla_client.HloSharding

 # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
@@ -315,12 +315,12 @@ class Exported:

   def xla_compatible_in_shardings(
       self,
-      mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
-    """Creates XLACompatibleShardings corresponding to self.in_shardings.
+      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
+    """Creates Shardings corresponding to self.in_shardings.

     The Exported object stores `in_shardings` as HloShardings, which are
     independent of a mesh or set of devices. This method constructs
-    XLACompatibleSharding that can be used in JAX APIs such as `jax.jit` or
+    Sharding that can be used in JAX APIs such as `jax.jit` or
     `jax.device_put`.

     Example usage:
@@ -354,8 +354,8 @@ class Exported:

   def xla_compatible_out_shardings(
       self,
-      mesh: sharding.Mesh) -> Sequence[sharding.XLACompatibleSharding | None]:
-    """Creates XLACompatibleShardings corresponding to self.out_shardings.
+      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
+    """Creates Shardings corresponding to self.out_shardings.

     See documentation for xla_compatible_in_shardings.
     """
@@ -978,7 +978,7 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding],

 def _hlo_sharding_to_xla_compatible_sharding(
     hlo_sharding: HloSharding | None,
-    mesh: sharding.Mesh) -> sharding.XLACompatibleSharding | None:
+    mesh: sharding.Mesh) -> sharding.Sharding | None:
   if hlo_sharding is None:
     return None
   return sharding_impls._gspmd_to_named_sharding_via_mesh(
@@ -3463,7 +3463,7 @@ def split_to_logical_devices(tensor: TfVal,


 def _xla_compatible_sharding_to_hlo_sharding(
-    s: sharding.XLACompatibleSharding,
+    s: sharding.Sharding,
     aval: core.ShapedArray) -> xla_client.HloSharding | None:
   if sharding_impls.is_unspecified(s):
     return None
@@ -3515,8 +3515,8 @@ def _shard_value(val: TfVal,

 def _pjit(*args: TfVal,
           jaxpr: core.ClosedJaxpr,
-          in_shardings: Sequence[sharding.XLACompatibleSharding],
-          out_shardings: Sequence[sharding.XLACompatibleSharding],
+          in_shardings: Sequence[sharding.Sharding],
+          out_shardings: Sequence[sharding.Sharding],
           in_layouts, out_layouts,
           resource_env: mesh.ResourceEnv,
           donated_invars,
@@ -3549,7 +3549,7 @@ tf_impl_with_avals[pjit.pjit_p] = _pjit


 def _pjit_sharding_constraint(arg: TfVal, *,
-                              sharding: sharding.XLACompatibleSharding,
+                              sharding: sharding.Sharding,
                               resource_env: mesh.ResourceEnv,
                               _in_avals: Sequence[core.ShapedArray],
                               _out_aval: core.ShapedArray,
@@ -17,7 +17,7 @@

 from jax._src.sharding import Sharding as Sharding
 from jax._src.sharding_impls import (
-    XLACompatibleSharding as XLACompatibleSharding,
+    XLACompatibleSharding as _deprecated_XLACompatibleSharding,
     NamedSharding as NamedSharding,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
@@ -28,3 +28,23 @@ from jax._src.partition_spec import (
     PartitionSpec as PartitionSpec,
 )
 from jax._src.interpreters.pxla import Mesh as Mesh
+
+_deprecations = {
+    # Added Jun 4, 2024.
+    "XLACompatibleSharding": (
+        (
+            "jax.sharding.XLACompatibleSharding is deprecated. Use"
+            " jax.sharding.Sharding instead."
+        ),
+        _deprecated_XLACompatibleSharding,
+    )
+}
+
+import typing
+if typing.TYPE_CHECKING:
+  XLACompatibleSharding = _deprecated_XLACompatibleSharding
+else:
+  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+  __getattr__ = _deprecation_getattr(__name__, _deprecations)
+  del _deprecation_getattr
+del typing
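The `_deprecations` table plus the module-level `__getattr__` added above is JAX's standard soft-deprecation mechanism: the name still resolves, but the access itself warns. A minimal sketch of the resulting user-visible behaviour (valid only while the deprecated alias still exists):

import warnings
import jax

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  # Attribute lookup is routed through the generated __getattr__, which
  # issues a DeprecationWarning and then returns the old class object.
  _ = jax.sharding.XLACompatibleSharding

assert any(issubclass(w.category, DeprecationWarning) for w in caught)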
@@ -3157,7 +3157,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     mesh = jtu.create_global_mesh((1,), ('x',))
     with self.assertRaisesRegex(
         RuntimeError,
-        "jax.jit only supports `XLACompatibleSharding`s being passed to "
+        "jax.jit only supports `Sharding`s being passed to "
         "in_shardings"):
       with mesh:
         jax.jit(lambda x: x, in_shardings=P('x'),