Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh.

PiperOrigin-RevId: 492358238
Yash Katariya 2022-12-01 19:28:02 -08:00 committed by jax authors
parent ed9519dadf
commit 934bc4e1b3
8 changed files with 104 additions and 110 deletions

View File

@ -28,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
[Parallelism with
JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html)
tutorial to understand the new concepts.
* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months (see the migration sketch below).
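
A minimal migration sketch for the entry above, assuming a jax/jaxlib build that already contains this change; the mesh shape and the axis names `'data'`/`'model'` are purely illustrative:

```python
# Deprecated import locations (scheduled for removal in ~3 months):
#   from jax.experimental import PartitionSpec
#   from jax.experimental.maps import Mesh
# New, stable locations:
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec

# Arrange the available devices into a 1 x N mesh (shape is illustrative).
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=('data', 'model'))
spec = PartitionSpec('data', 'model')
```

The old names keep working during the deprecation window, so both import styles can coexist while call sites are migrated.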
## jaxlib 0.4.0

View File

@ -18,24 +18,21 @@ import functools
from collections import Counter
import operator as op
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast, TYPE_CHECKING)
FrozenSet, Union, cast)
import jax
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax.interpreters import mlir
from jax.interpreters import pxla
import numpy as np
if TYPE_CHECKING:
from jax.interpreters import pxla
Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]
@use_cpp_class(xc.Sharding if xc._version >= 94 else None)
class Sharding(metaclass=abc.ABCMeta):
"""Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@ -137,9 +134,6 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
# TODO(https://github.com/google/jax/issues/12016): Remove the local import
from jax.interpreters import pxla
op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
if pxla.is_op_sharding_replicated(op_sharding):
return global_shape
@ -205,6 +199,42 @@ def _enable_cpp_named_sharding():
return None
class _UnconstrainedPartitionSingleton:
def __str__(self):
return "UNCONSTRAINED"
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
class PartitionSpec(tuple):
"""Tuple of integer specifying how a value should be partitioned.
Each integer corresponds to how many ways a dimension is partitioned. We
create a separate class for this so JAX's pytree utilities can distinguish it
from a tuple that should be treated as a pytree.
"""
# A sentinel value representing that a dim is unconstrained.
UNCONSTRAINED = _UNCONSTRAINED_PARTITION
def __init__(self, *partitions):
pass
def __new__(cls, *partitions):
return tuple.__new__(PartitionSpec, partitions)
def __repr__(self):
return "PartitionSpec%s" % tuple.__repr__(self)
def __reduce__(self):
return (PartitionSpec, tuple(self))
@use_cpp_class(_enable_cpp_named_sharding())
class NamedSharding(XLACompatibleSharding):
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
@ -241,7 +271,7 @@ class NamedSharding(XLACompatibleSharding):
@use_cpp_method
def __init__(
self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):
self.mesh = mesh
self.spec = spec
@ -309,7 +339,6 @@ class NamedSharding(XLACompatibleSharding):
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
) -> xc.OpSharding:
from jax.experimental.pjit import get_array_mapping
from jax.interpreters import pxla
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
@ -432,8 +461,6 @@ class PmapSharding(XLACompatibleSharding):
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
"""
from jax.interpreters import pxla
# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
aval = jax.ShapedArray(shape, np.int32)
@ -457,7 +484,6 @@ class PmapSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
from jax.interpreters import pxla
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@ -470,8 +496,6 @@ class PmapSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
from jax.interpreters import pxla
sharded_dim = None
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
@ -616,8 +640,6 @@ class OpShardingSharding(XLACompatibleSharding):
return hash(xc.HloSharding.from_proto(self._op_sharding))
def __eq__(self, other):
from jax.interpreters import pxla
if not isinstance(other, OpShardingSharding):
return False
if id(self) == id(other):
@ -634,8 +656,6 @@ class OpShardingSharding(XLACompatibleSharding):
return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
def is_compatible_aval(self, aval_shape: Shape):
from jax.interpreters import pxla
num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
if len(aval_shape) < len(num_ways_dim_sharded):
raise ValueError(
@ -649,8 +669,6 @@ class OpShardingSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
from jax.interpreters import pxla
indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
len(self._devices))
return dict(safe_zip(self._devices, indices))
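
To make the pieces in this file concrete, here is a small usage sketch, assuming a single-host setup with at least one device; the axis name, shapes, and print calls are illustrative only:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# A 1-D mesh over all local devices; 'x' is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

# Shard the leading array dimension over 'x' and replicate the second
# dimension (a None entry means "not partitioned").
sharding = NamedSharding(mesh, PartitionSpec('x', None))

# The XLACompatibleSharding helpers touched in the hunks above:
global_shape = (2 * len(jax.devices()), 4)
print(sharding.shard_shape(global_shape))          # per-device shard shape
print(sharding.devices_indices_map(global_shape))  # device -> index tuples
```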

View File

@ -46,7 +46,6 @@ from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
from jax.interpreters import mlir
from jax.experimental.maps import Mesh
# This submodule includes private test utilities that are not exported to
# jax.test_util. Functionality appearing here is for internal use only, and
@ -1000,7 +999,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape) # type: ignore
with Mesh(mesh_devices, axis_names):
with jax.sharding.Mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
@ -1040,7 +1039,7 @@ def create_global_mesh(mesh_shape, axis_names):
raise unittest.SkipTest(f"Test requires {size} global devices.")
devices = sorted(api.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
return global_mesh
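
Outside the test suite, the same pattern as `create_global_mesh` can be written against the public API; a rough standalone equivalent (the `make_mesh` name is hypothetical, and the `(1, 1)` shape assumes at least one device):

```python
import numpy as np
import jax

def make_mesh(mesh_shape, axis_names):
  # Mirror create_global_mesh above: take the first prod(mesh_shape) devices
  # in id order and arrange them into the requested shape.
  size = int(np.prod(mesh_shape))
  devices = sorted(jax.devices(), key=lambda d: d.id)
  if len(devices) < size:
    raise ValueError(f"Requires {size} devices, found {len(devices)}.")
  return jax.sharding.Mesh(np.array(devices[:size]).reshape(mesh_shape), axis_names)

mesh = make_mesh((1, 1), ('x', 'y'))
```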

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
# 3 months from `jax.experimental.PartitionSpec`.
from jax.interpreters.pxla import PartitionSpec as PartitionSpec
from jax.experimental.x64_context import (
enable_x64 as enable_x64,

View File

@ -111,6 +111,8 @@ class FrozenDict(abc.Mapping):
AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
# 3 months from `jax.experimental.maps.Mesh`.
Mesh = pxla.Mesh
ResourceEnv = pxla.ResourceEnv
EMPTY_ENV = pxla.EMPTY_ENV

View File

@ -68,9 +68,7 @@ from jax._src import util
from jax._src import dispatch
from jax._src import profiler
from jax._src import stages
from jax._src.sharding import (PmapSharding, NamedSharding, OpShardingSharding,
SingleDeviceSharding, XLACompatibleSharding,
_get_replicated_op_sharding)
from jax._src import sharding as sharding_internal
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.config import flags
@ -118,6 +116,7 @@ ShardingSpec = pmap_lib.ShardingSpec
MeshAxisName = Any
OpShardingType = Any
PartitionSpec = sharding_internal.PartitionSpec
def sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
@ -542,7 +541,7 @@ class OutputType(enum.Enum):
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: XLACompatibleSharding,
sharding: sharding_internal.XLACompatibleSharding,
indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
@ -844,7 +843,7 @@ def _sda_sharding(self):
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
if has_unstacked:
devices = np.array([d.device() for d in self.device_buffers])
return PmapSharding(devices, self.sharding_spec)
return sharding_internal.PmapSharding(devices, self.sharding_spec)
raise NotImplementedError(
'SDAs that are the output of pjit/xmap do not have the sharding attribute '
'implemented. If you are trying to pass the SDA to pjit/xmap, please '
@ -1541,9 +1540,9 @@ class UnloadedPmapExecutable:
compiled: Any
backend: xb.XlaBackend
local_input_avals: Sequence[jax.core.AbstractValue]
input_shardings: Sequence[XLACompatibleSharding]
input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
local_output_avals: Sequence[ShapedArray]
output_shardings: Sequence[XLACompatibleSharding]
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
keepalive: Sequence[Any]
@ -1729,7 +1728,7 @@ class PmapExecutable(stages.XlaExecutable):
def _get_pmap_sharding(devices, specs):
return [PmapSharding(devices, spec) for spec in specs]
return [sharding_internal.PmapSharding(devices, spec) for spec in specs]
multi_host_supported_collectives: Set[core.Primitive] = set()
@ -1915,11 +1914,11 @@ class ResultsHandler:
def _get_sharding_specs(
shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
shardings: Sequence[sharding_internal.XLACompatibleSharding], avals: Sequence[ShapedArray]
) -> Sequence[ShardingSpec]:
if all(isinstance(s, PmapSharding) for s in shardings):
if all(isinstance(s, sharding_internal.PmapSharding) for s in shardings):
return [s.sharding_spec for s in shardings] # type: ignore
elif all(isinstance(s, NamedSharding) for s in shardings):
elif all(isinstance(s, sharding_internal.NamedSharding) for s in shardings):
return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
aval.ndim, _get_array_mapping(s.spec))
for aval, s in safe_zip(avals, shardings)]
@ -1930,7 +1929,7 @@ def _get_sharding_specs(
def local_avals_to_results_handler(
unmapped_local_out_avals: Sequence[ShapedArray],
local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
local_shardings: Sequence[sharding_internal.XLACompatibleSharding]) -> ResultsHandler:
out_indices = [tuple(s.devices_indices_map(aval.shape).values())
for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
handlers = [
@ -1942,7 +1941,7 @@ def local_avals_to_results_handler(
def global_avals_to_results_handler(
global_out_avals: Sequence[ShapedArray],
shardings: Sequence[XLACompatibleSharding],
shardings: Sequence[sharding_internal.XLACompatibleSharding],
committed: bool,
are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
if config.jax_parallel_functions_output_gda or config.jax_array:
@ -1954,10 +1953,10 @@ def global_avals_to_results_handler(
return ResultsHandler(handlers, shardings, global_out_avals)
else:
# This path is taken when the outputs are SDAs.
assert all(isinstance(s, NamedSharding) for s in shardings)
assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
for aval, s in safe_zip(global_out_avals, shardings)]
local_shardings = [NamedSharding(s.mesh.local_mesh, s.spec) # type: ignore
local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec) # type: ignore
for s in shardings]
return local_avals_to_results_handler(local_out_avals, local_shardings)
@ -2702,52 +2701,16 @@ TilingMethod = Union[TileVectorize, TileManual]
def _check_if_any_auto(
shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
_UnspecifiedValue]]) -> bool:
shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
_AUTOAxisResource, _UnspecifiedValue]]) -> bool:
for s in shardings:
if _is_auto(s):
return True
return False
class _UnconstrainedPartitionSingleton:
def __str__(self):
return "UNCONSTRAINED"
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
class PartitionSpec(tuple):
"""Tuple of integer specifying how a value should be partitioned.
Each integer corresponds to how many ways a dimension is partitioned. We
create a separate class for this so JAX's pytree utilities can distinguish it
from a tuple that should be treated as a pytree.
"""
# A sentinel value representing a dim is unconstrained.
UNCONSTRAINED = _UNCONSTRAINED_PARTITION
def __init__(self, *partitions):
pass
def __new__(cls, *partitions):
return tuple.__new__(PartitionSpec, partitions)
def __repr__(self):
return "PartitionSpec%s" % tuple.__repr__(self)
def __reduce__(self):
return (PartitionSpec, tuple(self))
def _get_and_check_device_assignment(
shardings: Iterable[XLACompatibleSharding],
shardings: Iterable[sharding_internal.XLACompatibleSharding],
devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
from jax._src.api import local_devices
@ -2800,8 +2763,8 @@ def lower_sharding_computation(
fun: lu.WrappedFun,
api_name: str,
fun_name: str,
in_shardings: Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]],
out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]],
out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
in_is_global: Sequence[bool],
@ -2847,7 +2810,7 @@ def lower_sharding_computation(
jaxpr_sharding or
any(not _is_unspecified(o) for o in out_shardings)) # type: ignore
in_shardings = tuple(OpShardingSharding.get_replicated(device_assignment)
in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@ -3019,8 +2982,8 @@ def lower_mesh_computation(
api_name: str,
fun_name: str,
mesh: Mesh,
in_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource]],
out_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource,
in_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource,
_UnspecifiedValue]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
@ -3241,8 +3204,8 @@ class MeshComputation(stages.XlaLowering):
def _get_input_metadata(
global_in_avals: Sequence[ShapedArray],
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
Sequence[ShapedArray]]:
avals, shardings = _get_normalized_avals_and_shardings(
global_in_avals, in_shardings, in_is_global)
@ -3251,8 +3214,8 @@ def _get_input_metadata(
def _get_normalized_avals_and_shardings(
global_in_avals: Sequence[ShapedArray],
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[ShapedArray], Sequence[sharding_internal.XLACompatibleSharding]]:
avals = []
shardings = []
@ -3262,9 +3225,9 @@ def _get_normalized_avals_and_shardings(
aval = gaval
in_sharding = i
else:
assert isinstance(i, NamedSharding)
assert isinstance(i, sharding_internal.NamedSharding)
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
in_sharding = NamedSharding(i.mesh.local_mesh, i.spec)
in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
avals.append(aval)
shardings.append(in_sharding)
@ -3272,7 +3235,7 @@ def _get_normalized_avals_and_shardings(
def _get_input_indices(
avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
avals: Sequence[ShapedArray], shardings: Sequence[sharding_internal.XLACompatibleSharding]
) -> Sequence[Tuple[Optional[Index], ...]]:
input_indices = []
@ -3308,13 +3271,17 @@ def _get_op_sharding_shardings_from_executable(
# just return SingleDeviceShardings since we know the computation is running
# only on 1 device.
if len(device_assignment) == 1:
return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
[SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
return ([sharding_internal.SingleDeviceSharding(device_assignment[0])
for _ in range(num_in_avals)],
[sharding_internal.SingleDeviceSharding(device_assignment[0])
for _ in range(num_out_avals)])
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
for i in in_op_shardings]
out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
for o in out_op_shardings]
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
# put the sharding on ROOT instead of the tuple.
@ -3332,8 +3299,8 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
from jax.experimental import pjit
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
return ([NamedSharding(mesh, i) for i in in_pspec],
[NamedSharding(mesh, o) for o in out_pspec])
return ([sharding_internal.NamedSharding(mesh, i) for i in in_pspec],
[sharding_internal.NamedSharding(mesh, o) for o in out_pspec])
@dataclasses.dataclass
@ -3342,9 +3309,9 @@ class UnloadedMeshExecutable:
device_assignment: Sequence[xc.Device]
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[XLACompatibleSharding]
input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
output_avals: Sequence[ShapedArray]
output_shardings: Sequence[XLACompatibleSharding]
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
pmap_nreps: int
@ -3396,8 +3363,8 @@ class UnloadedMeshExecutable:
mesh: Optional[Mesh],
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource,
_UnspecifiedValue]],
spmd_lowering: bool,
tuple_args: bool,
@ -3638,9 +3605,9 @@ class MeshExecutable(stages.XlaExecutable):
def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[XLACompatibleSharding],
in_shardings: Sequence[sharding_internal.XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> List[XLACompatibleSharding]:
) -> List[sharding_internal.XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
@ -3648,9 +3615,9 @@ def _out_shardings_for_trivial(
# a replicated sharding
from jax._src import array
rep = OpShardingSharding(
device_assignment, _get_replicated_op_sharding())
shardings: Dict[core.Var, XLACompatibleSharding] = {}
rep = sharding_internal.OpShardingSharding(
device_assignment, sharding_internal._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
@ -3744,7 +3711,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@lru_cache()
def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
return NamedSharding(mesh, pspec, parsed_pspec)
return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):

View File

@ -19,8 +19,11 @@ from jax._src.sharding import (
MeshPspecSharding as MeshPspecSharding,
# New name of MeshPspecSharding to match PositionalSharding below.
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
OpShardingSharding as OpShardingSharding,
PositionalSharding as PositionalSharding,
)
from jax.interpreters.pxla import Mesh as Mesh

View File

@ -36,9 +36,8 @@ from jax import stages
from jax.errors import JAXTypeError
from jax import lax
from jax import prng
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.sharding import PartitionSpec as P
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap
from jax.experimental import global_device_array
from jax.experimental import multihost_utils
@ -328,7 +327,7 @@ class PJitTest(jtu.BufferDonationTestCase):
def testTwoMeshAxisSharding(self):
@partial(pjit,
in_axis_resources=P(('x', 'y'),),
out_axis_resources=P(('x', 'y'),))
out_axis_resources=jax.sharding.PartitionSpec(('x', 'y'),))
def f(x, y):
return x @ y
@ -2210,7 +2209,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
[devices[4], devices[6]],
[devices[7], devices[5]]])
shape = (8, 2)
mesh = maps.Mesh(mesh_devices, ('x', 'y'))
mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
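
One detail in `testTwoMeshAxisSharding` above deserves a note: a tuple of mesh axis names inside a single `PartitionSpec` entry shards that one array dimension across the product of those axes. A rough sketch of the same idea outside pjit, assuming a runtime with at least 8 devices (the device count and shapes are illustrative):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, ('x', 'y'))

# The leading dimension is split 4 * 2 = 8 ways; the second is replicated.
s = NamedSharding(mesh, P(('x', 'y'), None))
print(s.devices_indices_map((8, 2)))  # one row of the input per device
```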