Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh.
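
A minimal usage sketch of the new endpoints (the device count, axis names, and
the NamedSharding pairing below are illustrative only, not part of this change,
and assume a runtime with at least four devices):

    import numpy as np
    import jax
    from jax.sharding import Mesh, PartitionSpec, NamedSharding

    # Arrange the first four devices into a 2x2 mesh with named axes.
    devices = np.array(jax.devices()[:4]).reshape(2, 2)
    mesh = Mesh(devices, ('x', 'y'))

    # PartitionSpec names which mesh axis shards each array dimension.
    spec = PartitionSpec('x', 'y')
    sharding = NamedSharding(mesh, spec)

    # The old spellings, jax.experimental.maps.Mesh and
    # jax.experimental.PartitionSpec, keep working for now but are deprecated.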

PiperOrigin-RevId: 492358238
Yash Katariya 2022-12-01 19:28:02 -08:00 committed by jax authors
parent ed9519dadf
commit 934bc4e1b3
8 changed files with 104 additions and 110 deletions

View File

@@ -28,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
 [Parallelism with
 JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html)
 tutorial to understand the new concepts.
+* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
+  are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
+  `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
+  deprecated and will be removed in 3 months.
 ## jaxlib 0.4.0

View File

@@ -18,24 +18,21 @@ import functools
 from collections import Counter
 import operator as op
 from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
-                    FrozenSet, Union, cast, TYPE_CHECKING)
+                    FrozenSet, Union, cast)
 import jax
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
 from jax.interpreters import mlir
+from jax.interpreters import pxla
 import numpy as np
-if TYPE_CHECKING:
-  from jax.interpreters import pxla
 Shape = Tuple[int, ...]
 Device = xc.Device
 Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 @use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -137,9 +134,6 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    # TODO(https://github.com/google/jax/issues/12016): Remove the local import
-    from jax.interpreters import pxla
     op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
     if pxla.is_op_sharding_replicated(op_sharding):
       return global_shape
@@ -205,6 +199,42 @@ def _enable_cpp_named_sharding():
   return None
+class _UnconstrainedPartitionSingleton:
+  def __str__(self):
+    return "UNCONSTRAINED"
+# Unconstrained sentinel value for PartitionSpec, representing a dimension for
+# which the user wants XLA to assign the best partitioning.
+# TODO(yashkatariya): May rename to AUTO.
+_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+class PartitionSpec(tuple):
+  """Tuple of integer specifying how a value should be partitioned.
+  Each integer corresponds to how many ways a dimension is partitioned. We
+  create a separate class for this so JAX's pytree utilities can distinguish it
+  from a tuple that should be treated as a pytree.
+  """
+  # A sentinel value representing a dim is unconstrained.
+  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
+  def __init__(self, *partitions):
+    pass
+  def __new__(cls, *partitions):
+    return tuple.__new__(PartitionSpec, partitions)
+  def __repr__(self):
+    return "PartitionSpec%s" % tuple.__repr__(self)
+  def __reduce__(self):
+    return (PartitionSpec, tuple(self))
 @use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
   r"""NamedSharding is a way to express ``Sharding``\s using named axes.
@@ -241,7 +271,7 @@ class NamedSharding(XLACompatibleSharding):
   @use_cpp_method
   def __init__(
-      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
+      self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):
     self.mesh = mesh
     self.spec = spec
@@ -309,7 +339,6 @@ class NamedSharding(XLACompatibleSharding):
       axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
   ) -> xc.OpSharding:
     from jax.experimental.pjit import get_array_mapping
-    from jax.interpreters import pxla
     array_mapping = get_array_mapping(self._parsed_pspec)
     # TODO(yashkatariya): Move away from sharding spec in NamedSharding
@@ -432,8 +461,6 @@ class PmapSharding(XLACompatibleSharding):
       shape: The shape of the input array.
      sharded_dim: Dimension the input array is sharded on. Defaults to 0.
     """
-    from jax.interpreters import pxla
     # The dtype doesn't matter here. Its only used for creating the
     # sharding_spec.
     aval = jax.ShapedArray(shape, np.int32)
@@ -457,7 +484,6 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
     return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]
@@ -470,8 +496,6 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    from jax.interpreters import pxla
     sharded_dim = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
@@ -616,8 +640,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))
   def __eq__(self, other):
-    from jax.interpreters import pxla
     if not isinstance(other, OpShardingSharding):
       return False
     if id(self) == id(other):
@@ -634,8 +656,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
   def is_compatible_aval(self, aval_shape: Shape):
-    from jax.interpreters import pxla
     num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
     if len(aval_shape) < len(num_ways_dim_sharded):
       raise ValueError(
@@ -649,8 +669,6 @@ class OpShardingSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
     indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                           len(self._devices))
     return dict(safe_zip(self._devices, indices))

View File

@@ -46,7 +46,6 @@ from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
     check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
 from jax.interpreters import mlir
-from jax.experimental.maps import Mesh
 # This submodule includes private test utilities that are not exported to
 # jax.test_util. Functionality appearing here is for internal use only, and
@@ -1000,7 +999,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
   if len(local_devices) < size:
     raise unittest.SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)  # type: ignore
-  with Mesh(mesh_devices, axis_names):
+  with jax.sharding.Mesh(mesh_devices, axis_names):
     yield
 def with_mesh_from_kwargs(f):
@@ -1040,7 +1039,7 @@ def create_global_mesh(mesh_shape, axis_names):
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   devices = sorted(api.devices(), key=lambda d: d.id)
   mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-  global_mesh = Mesh(mesh_devices, axis_names)
+  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
   return global_mesh

View File

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
+# 3 months from `jax.experimental.PartitionSpec`.
 from jax.interpreters.pxla import PartitionSpec as PartitionSpec
 from jax.experimental.x64_context import (
     enable_x64 as enable_x64,

View File

@@ -111,6 +111,8 @@ class FrozenDict(abc.Mapping):
 AxisName = core.AxisName
 ResourceAxisName = AxisName  # Different name just for documentation purposes
+# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
+# 3 months from `jax.experimental.maps.Mesh`.
 Mesh = pxla.Mesh
 ResourceEnv = pxla.ResourceEnv
 EMPTY_ENV = pxla.EMPTY_ENV

View File

@@ -68,9 +68,7 @@ from jax._src import util
 from jax._src import dispatch
 from jax._src import profiler
 from jax._src import stages
-from jax._src.sharding import (PmapSharding, NamedSharding, OpShardingSharding,
-                               SingleDeviceSharding, XLACompatibleSharding,
-                               _get_replicated_op_sharding)
+from jax._src import sharding as sharding_internal
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config
 from jax._src.config import flags
@@ -118,6 +116,7 @@ ShardingSpec = pmap_lib.ShardingSpec
 MeshAxisName = Any
 OpShardingType = Any
+PartitionSpec = sharding_internal.PartitionSpec
 def sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -542,7 +541,7 @@ class OutputType(enum.Enum):
 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: XLACompatibleSharding,
+    sharding: sharding_internal.XLACompatibleSharding,
     indices: Optional[Tuple[Index, ...]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -844,7 +843,7 @@ def _sda_sharding(self):
   has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
   if has_unstacked:
     devices = np.array([d.device() for d in self.device_buffers])
-    return PmapSharding(devices, self.sharding_spec)
+    return sharding_internal.PmapSharding(devices, self.sharding_spec)
   raise NotImplementedError(
       'SDAs that are the output of pjit/xmap do not have the sharding attribute '
       'implemented. If you are trying to pass the SDA to pjit/xmap, please '
@@ -1541,9 +1540,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[jax.core.AbstractValue]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   unordered_effects: List[core.Effect]
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
@@ -1729,7 +1728,7 @@ class PmapExecutable(stages.XlaExecutable):
 def _get_pmap_sharding(devices, specs):
-  return [PmapSharding(devices, spec) for spec in specs]
+  return [sharding_internal.PmapSharding(devices, spec) for spec in specs]
 multi_host_supported_collectives: Set[core.Primitive] = set()
@@ -1915,11 +1914,11 @@ class ResultsHandler:
 def _get_sharding_specs(
-    shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
+    shardings: Sequence[sharding_internal.XLACompatibleSharding], avals: Sequence[ShapedArray]
 ) -> Sequence[ShardingSpec]:
-  if all(isinstance(s, PmapSharding) for s in shardings):
+  if all(isinstance(s, sharding_internal.PmapSharding) for s in shardings):
     return [s.sharding_spec for s in shardings]  # type: ignore
-  elif all(isinstance(s, NamedSharding) for s in shardings):
+  elif all(isinstance(s, sharding_internal.NamedSharding) for s in shardings):
     return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
                 aval.ndim, _get_array_mapping(s.spec))
             for aval, s in safe_zip(avals, shardings)]
@@ -1930,7 +1929,7 @@ def _get_sharding_specs(
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[sharding_internal.XLACompatibleSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1942,7 +1941,7 @@ def local_avals_to_results_handler(
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding],
+    shardings: Sequence[sharding_internal.XLACompatibleSharding],
     committed: bool,
     are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
   if config.jax_parallel_functions_output_gda or config.jax_array:
@@ -1954,10 +1953,10 @@ def global_avals_to_results_handler(
     return ResultsHandler(handlers, shardings, global_out_avals)
   else:
     # This path is taken when the outputs are SDAs.
-    assert all(isinstance(s, NamedSharding) for s in shardings)
+    assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
     local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
                        for aval, s in safe_zip(global_out_avals, shardings)]
-    local_shardings = [NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
+    local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
                        for s in shardings]
     return local_avals_to_results_handler(local_out_avals, local_shardings)
@@ -2702,52 +2701,16 @@ TilingMethod = Union[TileVectorize, TileManual]
 def _check_if_any_auto(
-    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
-                              _UnspecifiedValue]]) -> bool:
+    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
+                              _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
   return False
-class _UnconstrainedPartitionSingleton:
-  def __str__(self):
-    return "UNCONSTRAINED"
-# Unconstrained sentinel value for PartitionSpec, representing a dimension for
-# which the user wants XLA to assign the best partitioning.
-# TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
-class PartitionSpec(tuple):
-  """Tuple of integer specifying how a value should be partitioned.
-  Each integer corresponds to how many ways a dimension is partitioned. We
-  create a separate class for this so JAX's pytree utilities can distinguish it
-  from a tuple that should be treated as a pytree.
-  """
-  # A sentinel value representing a dim is unconstrained.
-  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
-  def __init__(self, *partitions):
-    pass
-  def __new__(cls, *partitions):
-    return tuple.__new__(PartitionSpec, partitions)
-  def __repr__(self):
-    return "PartitionSpec%s" % tuple.__repr__(self)
-  def __reduce__(self):
-    return (PartitionSpec, tuple(self))
 def _get_and_check_device_assignment(
-    shardings: Iterable[XLACompatibleSharding],
+    shardings: Iterable[sharding_internal.XLACompatibleSharding],
     devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
@@ -2800,8 +2763,8 @@ def lower_sharding_computation(
     fun: lu.WrappedFun,
     api_name: str,
     fun_name: str,
-    in_shardings: Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]],
-    out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
+    in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]],
+    out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     in_is_global: Sequence[bool],
@@ -2847,7 +2810,7 @@ def lower_sharding_computation(
       jaxpr_sharding or
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
-  in_shardings = tuple(OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -3019,8 +2982,8 @@ def lower_mesh_computation(
     api_name: str,
     fun_name: str,
     mesh: Mesh,
-    in_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource]],
-    out_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource,
+    in_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource]],
+    out_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource,
                                   _UnspecifiedValue]],
     donated_invars: Sequence[bool],
     spmd_lowering: bool,
@@ -3241,8 +3204,8 @@ class MeshComputation(stages.XlaLowering):
 def _get_input_metadata(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
            Sequence[ShapedArray]]:
   avals, shardings = _get_normalized_avals_and_shardings(
       global_in_avals, in_shardings, in_is_global)
@@ -3251,8 +3214,8 @@ def _get_input_metadata(
 def _get_normalized_avals_and_shardings(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[ShapedArray], Sequence[sharding_internal.XLACompatibleSharding]]:
   avals = []
   shardings = []
@@ -3262,9 +3225,9 @@ def _get_normalized_avals_and_shardings(
       aval = gaval
       in_sharding = i
     else:
-      assert isinstance(i, NamedSharding)
+      assert isinstance(i, sharding_internal.NamedSharding)
       aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
-      in_sharding = NamedSharding(i.mesh.local_mesh, i.spec)
+      in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
     avals.append(aval)
     shardings.append(in_sharding)
@@ -3272,7 +3235,7 @@ def _get_normalized_avals_and_shardings(
 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
+    avals: Sequence[ShapedArray], shardings: Sequence[sharding_internal.XLACompatibleSharding]
 ) -> Sequence[Tuple[Optional[Index], ...]]:
   input_indices = []
@@ -3308,13 +3271,17 @@ def _get_op_sharding_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+    return ([sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_in_avals)],
+            [sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_out_avals)])
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
-  in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
-  out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
+  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+                      for i in in_op_shardings]
+  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+                       for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
   # put the sharding on ROOT instead of the tuple.
@@ -3332,8 +3299,8 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
   from jax.experimental import pjit
   in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
-  return ([NamedSharding(mesh, i) for i in in_pspec],
-          [NamedSharding(mesh, o) for o in out_pspec])
+  return ([sharding_internal.NamedSharding(mesh, i) for i in in_pspec],
+          [sharding_internal.NamedSharding(mesh, o) for o in out_pspec])
 @dataclasses.dataclass
@@ -3342,9 +3309,9 @@ class UnloadedMeshExecutable:
   device_assignment: Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   committed: bool
   are_out_shardings_from_xla: Sequence[bool]
   pmap_nreps: int
@@ -3396,8 +3363,8 @@ class UnloadedMeshExecutable:
       mesh: Optional[Mesh],
       global_in_avals: Sequence[ShapedArray],
      global_out_avals: Sequence[ShapedArray],
-      in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
-      out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
+      in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource]],
+      out_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource,
                                     _UnspecifiedValue]],
       spmd_lowering: bool,
      tuple_args: bool,
@@ -3638,9 +3605,9 @@ class MeshExecutable(stages.XlaExecutable):
 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[XLACompatibleSharding],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding],
     device_assignment: Sequence[xc.Device],
-) -> List[XLACompatibleSharding]:
+) -> List[sharding_internal.XLACompatibleSharding]:
   # For each jaxpr output, compute a Sharding by:
   # * if the output is a forwarded input, get the corresponding in_sharding;
   # * if the output is a constant Array, get its .sharding attribute;
@@ -3648,9 +3615,9 @@ def _out_shardings_for_trivial(
   # a replicated sharding
   from jax._src import array
-  rep = OpShardingSharding(
-      device_assignment, _get_replicated_op_sharding())
-  shardings: Dict[core.Var, XLACompatibleSharding] = {}
+  rep = sharding_internal.OpShardingSharding(
+      device_assignment, sharding_internal._get_replicated_op_sharding())
+  shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
     if isinstance(constval, array.ArrayImpl):
       shardings[constvar] = constval.sharding
@@ -3744,7 +3711,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
 @lru_cache()
 def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
-  return NamedSharding(mesh, pspec, parsed_pspec)
+  return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)
 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):

View File

@@ -19,8 +19,11 @@ from jax._src.sharding import (
     MeshPspecSharding as MeshPspecSharding,
     # New name of MeshPspecSharding to match PositionalSharding below.
     NamedSharding as NamedSharding,
+    PartitionSpec as PartitionSpec,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
     OpShardingSharding as OpShardingSharding,
     PositionalSharding as PositionalSharding,
 )
+from jax.interpreters.pxla import Mesh as Mesh

View File

@@ -36,9 +36,8 @@ from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
 from jax import prng
-# TODO(skye): do we still wanna call this PartitionSpec?
+from jax.sharding import PartitionSpec as P
 from jax.experimental import maps
-from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap
 from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
@@ -328,7 +327,7 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
-             out_axis_resources=P(('x', 'y'),))
+             out_axis_resources=jax.sharding.PartitionSpec(('x', 'y'),))
     def f(x, y):
       return x @ y
@@ -2210,7 +2209,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
                              [devices[4], devices[6]],
                              [devices[7], devices[5]]])
     shape = (8, 2)
-    mesh = maps.Mesh(mesh_devices, ('x', 'y'))
+    mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
     s = NamedSharding(mesh, P('x', 'y'))
     inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)