Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh.

PiperOrigin-RevId: 492358238
parent ed9519dadf
commit 934bc4e1b3
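As a quick orientation before the diff, here is a minimal sketch (not part of the commit) of what the move means for user code. The axis names and the 1 x N device layout are placeholder choices; the deprecated spellings keep working for the 3-month window noted in the CHANGELOG hunk below.

# Hypothetical migration sketch, not part of the diff.
# Old, now-deprecated spellings:
#   from jax.experimental.maps import Mesh
#   from jax.experimental import PartitionSpec
# New spellings introduced by this change:
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

# Arrange whatever local devices exist into a 1 x N mesh (layout is illustrative).
devices = np.array(jax.local_devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=('x', 'y'))
spec = PartitionSpec('x', 'y')  # dim 0 split over 'x', dim 1 over 'y'
print(mesh.shape, spec)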
@@ -28,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
   [Parallelism with
   JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html)
   tutorial to understand the new concepts.
+* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
+  are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
+  `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
+  deprecated and will be removed in 3 months.
 
 ## jaxlib 0.4.0
 
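The CHANGELOG entry only changes where the classes are imported from; their behavior is unchanged. A short usage sketch of the new endpoints against a recent jax release follows; the 4x2 device grid, array shape, and axis names are assumptions for illustration only.

# Illustrative sketch, not part of the diff.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes at least 8 devices are available (e.g. a TPU slice, or CPUs via
# XLA_FLAGS=--xla_force_host_platform_device_count=8); purely illustrative.
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, PartitionSpec('x', 'y'))

x = jax.device_put(np.arange(32.0).reshape(8, 4), sharding)
print(x.sharding)  # a NamedSharding built from jax.sharding.Mesh and PartitionSpec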
@@ -18,24 +18,21 @@ import functools
 from collections import Counter
 import operator as op
 from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
-                    FrozenSet, Union, cast, TYPE_CHECKING)
+                    FrozenSet, Union, cast)
 
 import jax
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
 from jax.interpreters import mlir
+from jax.interpreters import pxla
 
 import numpy as np
 
-if TYPE_CHECKING:
-  from jax.interpreters import pxla
-
 Shape = Tuple[int, ...]
 Device = xc.Device
 Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
 
 @use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -137,9 +134,6 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
 
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    # TODO(https://github.com/google/jax/issues/12016): Remove the local import
-    from jax.interpreters import pxla
-
     op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
     if pxla.is_op_sharding_replicated(op_sharding):
       return global_shape
@@ -205,6 +199,42 @@ def _enable_cpp_named_sharding():
   return None
 
 
+class _UnconstrainedPartitionSingleton:
+
+  def __str__(self):
+    return "UNCONSTRAINED"
+
+
+# Unconstrained sentinel value for PartitionSpec, representing a dimension for
+# which the user wants XLA to assign the best partitioning.
+# TODO(yashkatariya): May rename to AUTO.
+_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+
+
+class PartitionSpec(tuple):
+  """Tuple of integer specifying how a value should be partitioned.
+
+  Each integer corresponds to how many ways a dimension is partitioned. We
+  create a separate class for this so JAX's pytree utilities can distinguish it
+  from a tuple that should be treated as a pytree.
+  """
+
+  # A sentinel value representing a dim is unconstrained.
+  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
+
+  def __init__(self, *partitions):
+    pass
+
+  def __new__(cls, *partitions):
+    return tuple.__new__(PartitionSpec, partitions)
+
+  def __repr__(self):
+    return "PartitionSpec%s" % tuple.__repr__(self)
+
+  def __reduce__(self):
+    return (PartitionSpec, tuple(self))
+
+
 @use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
   r"""NamedSharding is a way to express ``Sharding``\s using named axes.
@@ -241,7 +271,7 @@ class NamedSharding(XLACompatibleSharding):
 
   @use_cpp_method
   def __init__(
-      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
+      self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):
 
     self.mesh = mesh
     self.spec = spec
@@ -309,7 +339,6 @@ class NamedSharding(XLACompatibleSharding):
       axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
   ) -> xc.OpSharding:
     from jax.experimental.pjit import get_array_mapping
-    from jax.interpreters import pxla
 
     array_mapping = get_array_mapping(self._parsed_pspec)
     # TODO(yashkatariya): Move away from sharding spec in NamedSharding
@@ -432,8 +461,6 @@ class PmapSharding(XLACompatibleSharding):
       shape: The shape of the input array.
      sharded_dim: Dimension the input array is sharded on. Defaults to 0.
     """
-    from jax.interpreters import pxla
-
     # The dtype doesn't matter here. Its only used for creating the
     # sharding_spec.
     aval = jax.ShapedArray(shape, np.int32)
@@ -457,7 +484,6 @@ class PmapSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
     return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]
 
@@ -470,8 +496,6 @@ class PmapSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    from jax.interpreters import pxla
-
     sharded_dim = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
@@ -616,8 +640,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))
 
   def __eq__(self, other):
-    from jax.interpreters import pxla
-
     if not isinstance(other, OpShardingSharding):
       return False
     if id(self) == id(other):
@@ -634,8 +656,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
 
   def is_compatible_aval(self, aval_shape: Shape):
-    from jax.interpreters import pxla
-
     num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
     if len(aval_shape) < len(num_ways_dim_sharded):
       raise ValueError(
@@ -649,8 +669,6 @@ class OpShardingSharding(XLACompatibleSharding):
 
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
-
     indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                           len(self._devices))
     return dict(safe_zip(self._devices, indices))
@@ -46,7 +46,6 @@ from jax._src.public_test_util import ( # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
     check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
 from jax.interpreters import mlir
-from jax.experimental.maps import Mesh
 
 # This submodule includes private test utilities that are not exported to
 # jax.test_util. Functionality appearing here is for internal use only, and
@@ -1000,7 +999,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
   if len(local_devices) < size:
     raise unittest.SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)  # type: ignore
-  with Mesh(mesh_devices, axis_names):
+  with jax.sharding.Mesh(mesh_devices, axis_names):
     yield
 
 def with_mesh_from_kwargs(f):
@@ -1040,7 +1039,7 @@ def create_global_mesh(mesh_shape, axis_names):
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   devices = sorted(api.devices(), key=lambda d: d.id)
   mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-  global_mesh = Mesh(mesh_devices, axis_names)
+  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
   return global_mesh
 
 
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
+# 3 months from `jax.experimental.PartitionSpec`.
 from jax.interpreters.pxla import PartitionSpec as PartitionSpec
 from jax.experimental.x64_context import (
   enable_x64 as enable_x64,
|
@ -111,6 +111,8 @@ class FrozenDict(abc.Mapping):
|
||||
|
||||
AxisName = core.AxisName
|
||||
ResourceAxisName = AxisName # Different name just for documentation purposes
|
||||
# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
|
||||
# 3 months from `jax.experimental.maps.Mesh`.
|
||||
Mesh = pxla.Mesh
|
||||
ResourceEnv = pxla.ResourceEnv
|
||||
EMPTY_ENV = pxla.EMPTY_ENV
|
||||
|
@@ -68,9 +68,7 @@ from jax._src import util
 from jax._src import dispatch
 from jax._src import profiler
 from jax._src import stages
-from jax._src.sharding import (PmapSharding, NamedSharding, OpShardingSharding,
-                               SingleDeviceSharding, XLACompatibleSharding,
-                               _get_replicated_op_sharding)
+from jax._src import sharding as sharding_internal
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config
 from jax._src.config import flags
@@ -118,6 +116,7 @@ ShardingSpec = pmap_lib.ShardingSpec
 MeshAxisName = Any
 OpShardingType = Any
 
+PartitionSpec = sharding_internal.PartitionSpec
 
 def sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -542,7 +541,7 @@ class OutputType(enum.Enum):
 
 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: XLACompatibleSharding,
+    sharding: sharding_internal.XLACompatibleSharding,
     indices: Optional[Tuple[Index, ...]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -844,7 +843,7 @@ def _sda_sharding(self):
   has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
   if has_unstacked:
     devices = np.array([d.device() for d in self.device_buffers])
-    return PmapSharding(devices, self.sharding_spec)
+    return sharding_internal.PmapSharding(devices, self.sharding_spec)
   raise NotImplementedError(
       'SDAs that are the output of pjit/xmap do not have the sharding attribute '
       'implemented. If you are trying to pass the SDA to pjit/xmap, please '
@@ -1541,9 +1540,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[jax.core.AbstractValue]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   unordered_effects: List[core.Effect]
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
@@ -1729,7 +1728,7 @@ class PmapExecutable(stages.XlaExecutable):
 
 
 def _get_pmap_sharding(devices, specs):
-  return [PmapSharding(devices, spec) for spec in specs]
+  return [sharding_internal.PmapSharding(devices, spec) for spec in specs]
 
 
 multi_host_supported_collectives: Set[core.Primitive] = set()
@@ -1915,11 +1914,11 @@ class ResultsHandler:
 
 
 def _get_sharding_specs(
-    shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
+    shardings: Sequence[sharding_internal.XLACompatibleSharding], avals: Sequence[ShapedArray]
 ) -> Sequence[ShardingSpec]:
-  if all(isinstance(s, PmapSharding) for s in shardings):
+  if all(isinstance(s, sharding_internal.PmapSharding) for s in shardings):
     return [s.sharding_spec for s in shardings]  # type: ignore
-  elif all(isinstance(s, NamedSharding) for s in shardings):
+  elif all(isinstance(s, sharding_internal.NamedSharding) for s in shardings):
     return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
         aval.ndim, _get_array_mapping(s.spec))
             for aval, s in safe_zip(avals, shardings)]
@@ -1930,7 +1929,7 @@ def _get_sharding_specs(
 
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[sharding_internal.XLACompatibleSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1942,7 +1941,7 @@ def local_avals_to_results_handler(
 
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding],
+    shardings: Sequence[sharding_internal.XLACompatibleSharding],
     committed: bool,
     are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
   if config.jax_parallel_functions_output_gda or config.jax_array:
@@ -1954,10 +1953,10 @@ def global_avals_to_results_handler(
     return ResultsHandler(handlers, shardings, global_out_avals)
   else:
     # This path is taken when the outputs are SDAs.
-    assert all(isinstance(s, NamedSharding) for s in shardings)
+    assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
     local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
                        for aval, s in safe_zip(global_out_avals, shardings)]
-    local_shardings = [NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
+    local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
                        for s in shardings]
     return local_avals_to_results_handler(local_out_avals, local_shardings)
 
@@ -2702,52 +2701,16 @@ TilingMethod = Union[TileVectorize, TileManual]
 
 
 def _check_if_any_auto(
-    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
-                              _UnspecifiedValue]]) -> bool:
+    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
+                              _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
   return False
 
 
-class _UnconstrainedPartitionSingleton:
-
-  def __str__(self):
-    return "UNCONSTRAINED"
-
-
-# Unconstrained sentinel value for PartitionSpec, representing a dimension for
-# which the user wants XLA to assign the best partitioning.
-# TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
-
-
-class PartitionSpec(tuple):
-  """Tuple of integer specifying how a value should be partitioned.
-
-  Each integer corresponds to how many ways a dimension is partitioned. We
-  create a separate class for this so JAX's pytree utilities can distinguish it
-  from a tuple that should be treated as a pytree.
-  """
-
-  # A sentinel value representing a dim is unconstrained.
-  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
-
-  def __init__(self, *partitions):
-    pass
-
-  def __new__(cls, *partitions):
-    return tuple.__new__(PartitionSpec, partitions)
-
-  def __repr__(self):
-    return "PartitionSpec%s" % tuple.__repr__(self)
-
-  def __reduce__(self):
-    return (PartitionSpec, tuple(self))
-
-
 def _get_and_check_device_assignment(
-    shardings: Iterable[XLACompatibleSharding],
+    shardings: Iterable[sharding_internal.XLACompatibleSharding],
     devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
 
@@ -2800,8 +2763,8 @@ def lower_sharding_computation(
     fun: lu.WrappedFun,
     api_name: str,
     fun_name: str,
-    in_shardings: Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]],
-    out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
+    in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]],
+    out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     in_is_global: Sequence[bool],
@@ -2847,7 +2810,7 @@ def lower_sharding_computation(
       jaxpr_sharding or
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
 
-  in_shardings = tuple(OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -3019,8 +2982,8 @@ def lower_mesh_computation(
     api_name: str,
     fun_name: str,
     mesh: Mesh,
-    in_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource]],
-    out_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource,
+    in_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource]],
+    out_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource,
                                   _UnspecifiedValue]],
     donated_invars: Sequence[bool],
     spmd_lowering: bool,
@@ -3241,8 +3204,8 @@ class MeshComputation(stages.XlaLowering):
 
 def _get_input_metadata(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
            Sequence[ShapedArray]]:
   avals, shardings = _get_normalized_avals_and_shardings(
       global_in_avals, in_shardings, in_is_global)
@@ -3251,8 +3214,8 @@ def _get_input_metadata(
 
 def _get_normalized_avals_and_shardings(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[ShapedArray], Sequence[sharding_internal.XLACompatibleSharding]]:
   avals = []
   shardings = []
 
@@ -3262,9 +3225,9 @@ def _get_normalized_avals_and_shardings(
       aval = gaval
       in_sharding = i
     else:
-      assert isinstance(i, NamedSharding)
+      assert isinstance(i, sharding_internal.NamedSharding)
       aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
-      in_sharding = NamedSharding(i.mesh.local_mesh, i.spec)
+      in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
     avals.append(aval)
     shardings.append(in_sharding)
 
@@ -3272,7 +3235,7 @@ def _get_normalized_avals_and_shardings(
 
 
 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
+    avals: Sequence[ShapedArray], shardings: Sequence[sharding_internal.XLACompatibleSharding]
 ) -> Sequence[Tuple[Optional[Index], ...]]:
 
   input_indices = []
@@ -3308,13 +3271,17 @@ def _get_op_sharding_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+    return ([sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_in_avals)],
+            [sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_out_avals)])
 
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
 
-  in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
-  out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
+  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+                      for i in in_op_shardings]
+  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+                       for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
   # put the sharding on ROOT instead of the tuple.
@@ -3332,8 +3299,8 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
   from jax.experimental import pjit
 
   in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
-  return ([NamedSharding(mesh, i) for i in in_pspec],
-          [NamedSharding(mesh, o) for o in out_pspec])
+  return ([sharding_internal.NamedSharding(mesh, i) for i in in_pspec],
+          [sharding_internal.NamedSharding(mesh, o) for o in out_pspec])
 
 
 @dataclasses.dataclass
@@ -3342,9 +3309,9 @@ class UnloadedMeshExecutable:
   device_assignment: Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   committed: bool
   are_out_shardings_from_xla: Sequence[bool]
   pmap_nreps: int
@@ -3396,8 +3363,8 @@ class UnloadedMeshExecutable:
       mesh: Optional[Mesh],
      global_in_avals: Sequence[ShapedArray],
      global_out_avals: Sequence[ShapedArray],
-      in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
-      out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
+      in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource]],
+      out_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource,
                                     _UnspecifiedValue]],
      spmd_lowering: bool,
      tuple_args: bool,
@@ -3638,9 +3605,9 @@ class MeshExecutable(stages.XlaExecutable):
 
 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[XLACompatibleSharding],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding],
     device_assignment: Sequence[xc.Device],
-) -> List[XLACompatibleSharding]:
+) -> List[sharding_internal.XLACompatibleSharding]:
   # For each jaxpr output, compute a Sharding by:
   # * if the output is a forwarded input, get the corresponding in_sharding;
   # * if the output is a constant Array, get its .sharding attribute;
@@ -3648,9 +3615,9 @@ def _out_shardings_for_trivial(
   # a replicated sharding
   from jax._src import array
 
-  rep = OpShardingSharding(
-      device_assignment, _get_replicated_op_sharding())
-  shardings: Dict[core.Var, XLACompatibleSharding] = {}
+  rep = sharding_internal.OpShardingSharding(
+      device_assignment, sharding_internal._get_replicated_op_sharding())
+  shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
     if isinstance(constval, array.ArrayImpl):
       shardings[constvar] = constval.sharding
@@ -3744,7 +3711,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
 
 @lru_cache()
 def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
-  return NamedSharding(mesh, pspec, parsed_pspec)
+  return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)
 
 
 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
@@ -19,8 +19,11 @@ from jax._src.sharding import (
     MeshPspecSharding as MeshPspecSharding,
     # New name of MeshPspecSharding to match PositionalSharding below.
     NamedSharding as NamedSharding,
+    PartitionSpec as PartitionSpec,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
     OpShardingSharding as OpShardingSharding,
     PositionalSharding as PositionalSharding,
 )
+
+from jax.interpreters.pxla import Mesh as Mesh
@@ -36,9 +36,8 @@ from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
 from jax import prng
-# TODO(skye): do we still wanna call this PartitionSpec?
+from jax.sharding import PartitionSpec as P
 from jax.experimental import maps
-from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap
 from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
@@ -328,7 +327,7 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
-             out_axis_resources=P(('x', 'y'),))
+             out_axis_resources=jax.sharding.PartitionSpec(('x', 'y'),))
     def f(x, y):
       return x @ y
 
@@ -2210,7 +2209,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
                                  [devices[4], devices[6]],
                                  [devices[7], devices[5]]])
     shape = (8, 2)
-    mesh = maps.Mesh(mesh_devices, ('x', 'y'))
+    mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
     s = NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)