diff --git a/CHANGELOG.md b/CHANGELOG.md
index a59f47b8e..f918ca834 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
     [Parallelism with JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html)
     tutorial to understand the new concepts.
+  * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
+    are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
+    `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
+    deprecated and will be removed in 3 months.
 
 ## jaxlib 0.4.0
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index 4cfc4633c..79e8f41a6 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -18,24 +18,21 @@
 import functools
 from collections import Counter
 import operator as op
 from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
-                    FrozenSet, Union, cast, TYPE_CHECKING)
+                    FrozenSet, Union, cast)
 
 import jax
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
 from jax.interpreters import mlir
+from jax.interpreters import pxla
 import numpy as np
 
-if TYPE_CHECKING:
-  from jax.interpreters import pxla
-
 Shape = Tuple[int, ...]
 Device = xc.Device
 Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
-
 @use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -137,9 +134,6 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
 
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    # TODO(https://github.com/google/jax/issues/12016): Remove the local import
-    from jax.interpreters import pxla
-
    op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
    if pxla.is_op_sharding_replicated(op_sharding):
      return global_shape
@@ -205,6 +199,42 @@ def _enable_cpp_named_sharding():
   return None
 
 
+class _UnconstrainedPartitionSingleton:
+
+  def __str__(self):
+    return "UNCONSTRAINED"
+
+
+# Unconstrained sentinel value for PartitionSpec, representing a dimension for
+# which the user wants XLA to assign the best partitioning.
+# TODO(yashkatariya): May rename to AUTO.
+_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+
+
+class PartitionSpec(tuple):
+  """Tuple of integer specifying how a value should be partitioned.
+
+  Each integer corresponds to how many ways a dimension is partitioned. We
+  create a separate class for this so JAX's pytree utilities can distinguish it
+  from a tuple that should be treated as a pytree.
+  """
+
+  # A sentinel value representing a dim is unconstrained.
+  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
+
+  def __init__(self, *partitions):
+    pass
+
+  def __new__(cls, *partitions):
+    return tuple.__new__(PartitionSpec, partitions)
+
+  def __repr__(self):
+    return "PartitionSpec%s" % tuple.__repr__(self)
+
+  def __reduce__(self):
+    return (PartitionSpec, tuple(self))
+
+
 @use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
   r"""NamedSharding is a way to express ``Sharding``\s using named axes.
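An aside on the class added in the hunk above: `PartitionSpec` remains a thin `tuple` subclass, so specs keep ordinary tuple indexing and equality while staying distinguishable from plain tuples in JAX's pytree handling. A minimal usage sketch, assuming a JAX build that already contains this change (only names visible in the hunk are used):

```python
# Minimal sketch; assumes a JAX build where jax.sharding.PartitionSpec exists,
# i.e. one that includes the change above.
from jax.sharding import PartitionSpec as P

spec = P('x', None)             # dim 0 over mesh axis 'x', dim 1 replicated
assert isinstance(spec, tuple)  # tuple subclass: indexing and equality still work
assert spec == ('x', None)
print(repr(spec))               # -> PartitionSpec('x', None)

# The UNCONSTRAINED sentinel marks a dimension whose partitioning XLA should
# choose on its own (see the _UNCONSTRAINED_PARTITION comment above).
auto_spec = P(P.UNCONSTRAINED, 'x')
```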
@@ -241,7 +271,7 @@ class NamedSharding(XLACompatibleSharding):
 
   @use_cpp_method
   def __init__(
-      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
+      self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):
 
     self.mesh = mesh
     self.spec = spec
@@ -309,7 +339,6 @@ class NamedSharding(XLACompatibleSharding):
       axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
   ) -> xc.OpSharding:
     from jax.experimental.pjit import get_array_mapping
-    from jax.interpreters import pxla
 
     array_mapping = get_array_mapping(self._parsed_pspec)
     # TODO(yashkatariya): Move away from sharding spec in NamedSharding
@@ -432,8 +461,6 @@ class PmapSharding(XLACompatibleSharding):
       shape: The shape of the input array.
       sharded_dim: Dimension the input array is sharded on. Defaults to 0.
     """
-    from jax.interpreters import pxla
-
     # The dtype doesn't matter here. Its only used for creating the
     # sharding_spec.
     aval = jax.ShapedArray(shape, np.int32)
@@ -457,7 +484,6 @@ class PmapSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
     return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]
 
@@ -470,8 +496,6 @@ class PmapSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    from jax.interpreters import pxla
-
     sharded_dim = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
@@ -616,8 +640,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))
 
   def __eq__(self, other):
-    from jax.interpreters import pxla
-
     if not isinstance(other, OpShardingSharding):
       return False
     if id(self) == id(other):
@@ -634,8 +656,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
 
   def is_compatible_aval(self, aval_shape: Shape):
-    from jax.interpreters import pxla
-
     num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
     if len(aval_shape) < len(num_ways_dim_sharded):
       raise ValueError(
@@ -649,8 +669,6 @@ class OpShardingSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
-
     indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                           len(self._devices))
     return dict(safe_zip(self._devices, indices))
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index abd0078f0..cb2ccef80 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -46,7 +46,6 @@ from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
     check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
 from jax.interpreters import mlir
-from jax.experimental.maps import Mesh
 
 # This submodule includes private test utilities that are not exported to
 # jax.test_util. Functionality appearing here is for internal use only, and
@@ -1000,7 +999,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
   if len(local_devices) < size:
     raise unittest.SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)  # type: ignore
-  with Mesh(mesh_devices, axis_names):
+  with jax.sharding.Mesh(mesh_devices, axis_names):
     yield
 
 def with_mesh_from_kwargs(f):
@@ -1040,7 +1039,7 @@ def create_global_mesh(mesh_shape, axis_names):
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   devices = sorted(api.devices(), key=lambda d: d.id)
   mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-  global_mesh = Mesh(mesh_devices, axis_names)
+  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
   return global_mesh
 
 
diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py
index 7d55b9de0..11589bac5 100644
--- a/jax/experimental/__init__.py
+++ b/jax/experimental/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
+# 3 months from `jax.experimental.PartitionSpec`.
 from jax.interpreters.pxla import PartitionSpec as PartitionSpec
 from jax.experimental.x64_context import (
   enable_x64 as enable_x64,
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index 9bead238e..d0d1f9238 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -111,6 +111,8 @@ class FrozenDict(abc.Mapping):
 
 AxisName = core.AxisName
 ResourceAxisName = AxisName  # Different name just for documentation purposes
+# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
+# 3 months from `jax.experimental.maps.Mesh`.
 Mesh = pxla.Mesh
 ResourceEnv = pxla.ResourceEnv
 EMPTY_ENV = pxla.EMPTY_ENV
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index f87b45e56..2ec9db524 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -68,9 +68,7 @@ from jax._src import util
 from jax._src import dispatch
 from jax._src import profiler
 from jax._src import stages
-from jax._src.sharding import (PmapSharding, NamedSharding, OpShardingSharding,
-                               SingleDeviceSharding, XLACompatibleSharding,
-                               _get_replicated_op_sharding)
+from jax._src import sharding as sharding_internal
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config
 from jax._src.config import flags
@@ -118,6 +116,7 @@ ShardingSpec = pmap_lib.ShardingSpec
 
 MeshAxisName = Any
 OpShardingType = Any
+PartitionSpec = sharding_internal.PartitionSpec
 
 def sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -542,7 +541,7 @@ class OutputType(enum.Enum):
 
 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: XLACompatibleSharding,
+    sharding: sharding_internal.XLACompatibleSharding,
     indices: Optional[Tuple[Index, ...]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -844,7 +843,7 @@ def _sda_sharding(self):
   has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
   if has_unstacked:
     devices = np.array([d.device() for d in self.device_buffers])
-    return PmapSharding(devices, self.sharding_spec)
+    return sharding_internal.PmapSharding(devices, self.sharding_spec)
   raise NotImplementedError(
       'SDAs that are the output of pjit/xmap do not have the sharding attribute '
       'implemented. If you are trying to pass the SDA to pjit/xmap, please '
@@ -1541,9 +1540,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[jax.core.AbstractValue]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   unordered_effects: List[core.Effect]
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
@@ -1729,7 +1728,7 @@ class PmapExecutable(stages.XlaExecutable):
 
 
 def _get_pmap_sharding(devices, specs):
-  return [PmapSharding(devices, spec) for spec in specs]
+  return [sharding_internal.PmapSharding(devices, spec) for spec in specs]
 
 
 multi_host_supported_collectives: Set[core.Primitive] = set()
@@ -1915,11 +1914,11 @@ class ResultsHandler:
 
 
 def _get_sharding_specs(
-    shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
+    shardings: Sequence[sharding_internal.XLACompatibleSharding], avals: Sequence[ShapedArray]
 ) -> Sequence[ShardingSpec]:
-  if all(isinstance(s, PmapSharding) for s in shardings):
+  if all(isinstance(s, sharding_internal.PmapSharding) for s in shardings):
     return [s.sharding_spec for s in shardings]  # type: ignore
-  elif all(isinstance(s, NamedSharding) for s in shardings):
+  elif all(isinstance(s, sharding_internal.NamedSharding) for s in shardings):
     return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
         aval.ndim, _get_array_mapping(s.spec))
             for aval, s in safe_zip(avals, shardings)]
@@ -1930,7 +1929,7 @@ def _get_sharding_specs(
 
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[sharding_internal.XLACompatibleSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1942,7 +1941,7 @@ def local_avals_to_results_handler(
 
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding],
+    shardings: Sequence[sharding_internal.XLACompatibleSharding],
     committed: bool,
     are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
   if config.jax_parallel_functions_output_gda or config.jax_array:
@@ -1954,10 +1953,10 @@ def global_avals_to_results_handler(
     return ResultsHandler(handlers, shardings, global_out_avals)
   else:
     # This path is taken when the outputs are SDAs.
-    assert all(isinstance(s, NamedSharding) for s in shardings)
+    assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
     local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
                        for aval, s in safe_zip(global_out_avals, shardings)]
-    local_shardings = [NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
+    local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
                        for s in shardings]
     return local_avals_to_results_handler(local_out_avals, local_shardings)
@@ -2702,52 +2701,16 @@ TilingMethod = Union[TileVectorize, TileManual]
 
 
 def _check_if_any_auto(
-    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
-                              _UnspecifiedValue]]) -> bool:
+    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
+                              _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
   return False
 
 
-class _UnconstrainedPartitionSingleton:
-
-  def __str__(self):
-    return "UNCONSTRAINED"
-
-
-# Unconstrained sentinel value for PartitionSpec, representing a dimension for
-# which the user wants XLA to assign the best partitioning.
-# TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
-
-
-class PartitionSpec(tuple):
-  """Tuple of integer specifying how a value should be partitioned.
-
-  Each integer corresponds to how many ways a dimension is partitioned. We
-  create a separate class for this so JAX's pytree utilities can distinguish it
-  from a tuple that should be treated as a pytree.
-  """
-
-  # A sentinel value representing a dim is unconstrained.
-  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
-
-  def __init__(self, *partitions):
-    pass
-
-  def __new__(cls, *partitions):
-    return tuple.__new__(PartitionSpec, partitions)
-
-  def __repr__(self):
-    return "PartitionSpec%s" % tuple.__repr__(self)
-
-  def __reduce__(self):
-    return (PartitionSpec, tuple(self))
-
-
 def _get_and_check_device_assignment(
-    shardings: Iterable[XLACompatibleSharding],
+    shardings: Iterable[sharding_internal.XLACompatibleSharding],
     devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
@@ -2800,8 +2763,8 @@ def lower_sharding_computation(
     fun: lu.WrappedFun,
     api_name: str,
     fun_name: str,
-    in_shardings: Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]],
-    out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
+    in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]],
+    out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     in_is_global: Sequence[bool],
@@ -2847,7 +2810,7 @@ def lower_sharding_computation(
       jaxpr_sharding or
      any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
 
-  in_shardings = tuple(OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -3019,8 +2982,8 @@ def lower_mesh_computation(
     api_name: str,
     fun_name: str,
     mesh: Mesh,
-    in_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource]],
-    out_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource,
+    in_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource]],
+    out_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource,
                              _UnspecifiedValue]],
     donated_invars: Sequence[bool],
     spmd_lowering: bool,
@@ -3241,8 +3204,8 @@ class MeshComputation(stages.XlaLowering):
 
 
 def _get_input_metadata(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
            Sequence[ShapedArray]]:
   avals, shardings = _get_normalized_avals_and_shardings(
       global_in_avals, in_shardings, in_is_global)
@@ -3251,8 +3214,8 @@ def _get_input_metadata(
 
 
 def _get_normalized_avals_and_shardings(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[ShapedArray], Sequence[sharding_internal.XLACompatibleSharding]]:
   avals = []
   shardings = []
 
@@ -3262,9 +3225,9 @@ def _get_normalized_avals_and_shardings(
       aval = gaval
       in_sharding = i
     else:
-      assert isinstance(i, NamedSharding)
+      assert isinstance(i, sharding_internal.NamedSharding)
       aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)),
                                      gaval)
-      in_sharding = NamedSharding(i.mesh.local_mesh, i.spec)
+      in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
     avals.append(aval)
     shardings.append(in_sharding)
@@ -3272,7 +3235,7 @@ def _get_normalized_avals_and_shardings(
 
 
 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
+    avals: Sequence[ShapedArray], shardings: Sequence[sharding_internal.XLACompatibleSharding]
 ) -> Sequence[Tuple[Optional[Index], ...]]:
 
   input_indices = []
@@ -3308,13 +3271,17 @@ def _get_op_sharding_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+    return ([sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_in_avals)],
+            [sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_out_avals)])
 
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
 
-  in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
-  out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
+  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+                      for i in in_op_shardings]
+  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+                       for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
   # put the sharding on ROOT instead of the tuple.
@@ -3332,8 +3299,8 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
   from jax.experimental import pjit
 
   in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
-  return ([NamedSharding(mesh, i) for i in in_pspec],
-          [NamedSharding(mesh, o) for o in out_pspec])
+  return ([sharding_internal.NamedSharding(mesh, i) for i in in_pspec],
+          [sharding_internal.NamedSharding(mesh, o) for o in out_pspec])
 
 
 @dataclasses.dataclass
@@ -3342,9 +3309,9 @@ class UnloadedMeshExecutable:
   device_assignment: Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   committed: bool
   are_out_shardings_from_xla: Sequence[bool]
   pmap_nreps: int
@@ -3396,8 +3363,8 @@ class UnloadedMeshExecutable:
       mesh: Optional[Mesh],
       global_in_avals: Sequence[ShapedArray],
      global_out_avals: Sequence[ShapedArray],
-      in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
-      out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
+      in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource]],
+      out_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource,
                               _UnspecifiedValue]],
       spmd_lowering: bool,
       tuple_args: bool,
@@ -3638,9 +3605,9 @@ class MeshExecutable(stages.XlaExecutable):
 
 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[XLACompatibleSharding],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding],
     device_assignment: Sequence[xc.Device],
-  ) -> List[XLACompatibleSharding]:
+  ) -> List[sharding_internal.XLACompatibleSharding]:
   # For each jaxpr output, compute a Sharding by:
   # * if the output is a forwarded input, get the corresponding in_sharding;
   # * if the output is a constant Array, get its .sharding attribute;
@@ -3648,9 +3615,9 @@ def _out_shardings_for_trivial(
   #   a replicated sharding
   from jax._src import array
 
-  rep = OpShardingSharding(
-      device_assignment, _get_replicated_op_sharding())
-  shardings: Dict[core.Var, XLACompatibleSharding] = {}
+  rep = sharding_internal.OpShardingSharding(
+      device_assignment, sharding_internal._get_replicated_op_sharding())
+  shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
     if isinstance(constval, array.ArrayImpl):
       shardings[constvar] = constval.sharding
@@ -3744,7 +3711,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
 
 @lru_cache()
 def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
-  return NamedSharding(mesh, pspec, parsed_pspec)
+  return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)
 
 
 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
diff --git a/jax/sharding.py b/jax/sharding.py
index 7a120edfd..a23e1ae90 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -19,8 +19,11 @@ from jax._src.sharding import (
     MeshPspecSharding as MeshPspecSharding,
     # New name of MeshPspecSharding to match PositionalSharding below.
     NamedSharding as NamedSharding,
+    PartitionSpec as PartitionSpec,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
     OpShardingSharding as OpShardingSharding,
     PositionalSharding as PositionalSharding,
 )
+
+from jax.interpreters.pxla import Mesh as Mesh
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b8f1e734e..9e3e25245 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -36,9 +36,8 @@ from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
 from jax import prng
-# TODO(skye): do we still wanna call this PartitionSpec?
+from jax.sharding import PartitionSpec as P
 from jax.experimental import maps
-from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap
 from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
@@ -328,7 +327,7 @@ class PJitTest(jtu.BufferDonationTestCase):
 
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
-             out_axis_resources=P(('x', 'y'),))
+             out_axis_resources=jax.sharding.PartitionSpec(('x', 'y'),))
     def f(x, y):
       return x @ y
@@ -2210,7 +2209,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
                                  [devices[4], devices[6]],
                                  [devices[7], devices[5]]])
     shape = (8, 2)
-    mesh = maps.Mesh(mesh_devices, ('x', 'y'))
+    mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
     s = NamedSharding(mesh, P('x', 'y'))
     inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
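Taken together, the downstream migration this diff implies is mostly an import swap: the `jax.experimental` spellings keep working for the deprecation window noted in the CHANGELOG hunk, while new code can target `jax.sharding` directly. A hedged sketch of the new spellings follows; the 4x2 device grid and the device count are illustrative assumptions, not part of the patch.

```python
# Sketch of the post-move spellings; assumes at least 8 devices are available
# (e.g. one TPU host) purely so a 4x2 mesh can be built.
import numpy as np
import jax
# Deprecated spellings this patch phases out:
#   from jax.experimental.maps import Mesh
#   from jax.experimental import PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, ('x', 'y'))             # named 4x2 device mesh
sharding = NamedSharding(mesh, P('x', 'y'))  # shard dim 0 over 'x', dim 1 over 'y'

# Which slice of an (8, 2) array each device would hold under this sharding.
for device, index in sharding.devices_indices_map((8, 2)).items():
  print(device, index)
```

The test helpers touched by the diff (for example `create_global_mesh` in `jax/_src/test_util.py`) build their meshes the same way, just via `jax.sharding.Mesh`.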