Move PartitionSpec and Mesh out of experimental and into the sharding
namespace. The new API endpoints are jax.sharding.PartitionSpec and
jax.sharding.Mesh.

PiperOrigin-RevId: 492358238
parent ed9519dadf
commit 934bc4e1b3
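For illustration, a minimal sketch of the relocated API under its new names. The device layout and axis names below are assumptions for the example, not part of this commit:

```python
import numpy as np
import jax
# New public endpoints introduced by this change:
from jax.sharding import Mesh, PartitionSpec

# Arrange whatever devices are available into a 1 x N grid (illustrative only).
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=('x', 'y'))   # previously jax.experimental.maps.Mesh
spec = PartitionSpec('x', 'y')                # previously jax.experimental.PartitionSpec
```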
@@ -28,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
   [Parallelism with
   JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html)
   tutorial to understand the new concepts.
+* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
+  are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
+  `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
+  deprecated and will be removed in 3 months.


 ## jaxlib 0.4.0
@@ -18,24 +18,21 @@ import functools
 from collections import Counter
 import operator as op
 from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
-                    FrozenSet, Union, cast, TYPE_CHECKING)
+                    FrozenSet, Union, cast)

 import jax
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
 from jax.interpreters import mlir
+from jax.interpreters import pxla

 import numpy as np

-if TYPE_CHECKING:
-  from jax.interpreters import pxla
-
 Shape = Tuple[int, ...]
 Device = xc.Device
 Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]


 @use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -137,9 +134,6 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):

   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    # TODO(https://github.com/google/jax/issues/12016): Remove the local import
-    from jax.interpreters import pxla
-
     op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
     if pxla.is_op_sharding_replicated(op_sharding):
       return global_shape
@@ -205,6 +199,42 @@ def _enable_cpp_named_sharding():
   return None


+class _UnconstrainedPartitionSingleton:
+
+  def __str__(self):
+    return "UNCONSTRAINED"
+
+
+# Unconstrained sentinel value for PartitionSpec, representing a dimension for
+# which the user wants XLA to assign the best partitioning.
+# TODO(yashkatariya): May rename to AUTO.
+_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+
+
+class PartitionSpec(tuple):
+  """Tuple of integer specifying how a value should be partitioned.
+
+  Each integer corresponds to how many ways a dimension is partitioned. We
+  create a separate class for this so JAX's pytree utilities can distinguish it
+  from a tuple that should be treated as a pytree.
+  """
+
+  # A sentinel value representing a dim is unconstrained.
+  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
+
+  def __init__(self, *partitions):
+    pass
+
+  def __new__(cls, *partitions):
+    return tuple.__new__(PartitionSpec, partitions)
+
+  def __repr__(self):
+    return "PartitionSpec%s" % tuple.__repr__(self)
+
+  def __reduce__(self):
+    return (PartitionSpec, tuple(self))
+
+
 @use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
   r"""NamedSharding is a way to express ``Sharding``\s using named axes.
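A small usage sketch of the class added above, paired with the NamedSharding class defined just below it. The mesh layout and axis names are assumptions for the example:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# PartitionSpec entries name mesh axes; None replicates a dimension, and
# PartitionSpec.UNCONSTRAINED leaves the choice to XLA.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('x', 'y'))
spec = PartitionSpec('x', None)        # shard dim 0 over 'x', replicate dim 1
sharding = NamedSharding(mesh, spec)   # mesh + spec form a concrete sharding
```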
@@ -241,7 +271,7 @@ class NamedSharding(XLACompatibleSharding):

   @use_cpp_method
   def __init__(
-      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
+      self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):

     self.mesh = mesh
     self.spec = spec
@@ -309,7 +339,6 @@ class NamedSharding(XLACompatibleSharding):
       axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
   ) -> xc.OpSharding:
     from jax.experimental.pjit import get_array_mapping
-    from jax.interpreters import pxla

     array_mapping = get_array_mapping(self._parsed_pspec)
     # TODO(yashkatariya): Move away from sharding spec in NamedSharding
@@ -432,8 +461,6 @@ class PmapSharding(XLACompatibleSharding):
       shape: The shape of the input array.
       sharded_dim: Dimension the input array is sharded on. Defaults to 0.
     """
-    from jax.interpreters import pxla
-
     # The dtype doesn't matter here. Its only used for creating the
     # sharding_spec.
     aval = jax.ShapedArray(shape, np.int32)
@@ -457,7 +484,6 @@ class PmapSharding(XLACompatibleSharding):

   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
     return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]

@@ -470,8 +496,6 @@ class PmapSharding(XLACompatibleSharding):

   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
-    from jax.interpreters import pxla
-
     sharded_dim = None
     for i, s in enumerate(self.sharding_spec.sharding):
       if isinstance(s, pxla.Unstacked):
@@ -616,8 +640,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))

   def __eq__(self, other):
-    from jax.interpreters import pxla
-
     if not isinstance(other, OpShardingSharding):
       return False
     if id(self) == id(other):
@@ -634,8 +656,6 @@ class OpShardingSharding(XLACompatibleSharding):
     return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'

   def is_compatible_aval(self, aval_shape: Shape):
-    from jax.interpreters import pxla
-
     num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
     if len(aval_shape) < len(num_ways_dim_sharded):
       raise ValueError(
@@ -649,8 +669,6 @@ class OpShardingSharding(XLACompatibleSharding):

   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
-    from jax.interpreters import pxla
-
     indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                           len(self._devices))
     return dict(safe_zip(self._devices, indices))
@@ -46,7 +46,6 @@ from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
     check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
 from jax.interpreters import mlir
-from jax.experimental.maps import Mesh

 # This submodule includes private test utilities that are not exported to
 # jax.test_util. Functionality appearing here is for internal use only, and
@@ -1000,7 +999,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
   if len(local_devices) < size:
     raise unittest.SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)  # type: ignore
-  with Mesh(mesh_devices, axis_names):
+  with jax.sharding.Mesh(mesh_devices, axis_names):
     yield

 def with_mesh_from_kwargs(f):
@@ -1040,7 +1039,7 @@ def create_global_mesh(mesh_shape, axis_names):
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   devices = sorted(api.devices(), key=lambda d: d.id)
   mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-  global_mesh = Mesh(mesh_devices, axis_names)
+  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
   return global_mesh

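The updated helpers rely on Mesh being usable as a context manager; a rough sketch of that pattern outside the test harness (the device layout and axis name are assumptions):

```python
import numpy as np
import jax

# A 1-D mesh over all local devices, with a single axis named 'x' (illustrative only).
mesh_devices = np.array(jax.local_devices())
with jax.sharding.Mesh(mesh_devices, ('x',)):
  # Inside this block, pjit/xmap resolve the mesh axis name 'x' to these devices.
  pass
```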
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
+# 3 months from `jax.experimental.PartitionSpec`.
 from jax.interpreters.pxla import PartitionSpec as PartitionSpec
 from jax.experimental.x64_context import (
   enable_x64 as enable_x64,
@@ -111,6 +111,8 @@ class FrozenDict(abc.Mapping):

 AxisName = core.AxisName
 ResourceAxisName = AxisName  # Different name just for documentation purposes
+# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
+# 3 months from `jax.experimental.maps.Mesh`.
 Mesh = pxla.Mesh
 ResourceEnv = pxla.ResourceEnv
 EMPTY_ENV = pxla.EMPTY_ENV
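During the deprecation window the old and new names refer to the same class: this hunk keeps `Mesh = pxla.Mesh`, and `jax.sharding` re-exports `pxla.Mesh` (see the `jax.sharding` hunk further below). A quick check, assuming both modules are importable in this version:

```python
import jax
from jax.experimental import maps

# Both spellings resolve to jax.interpreters.pxla.Mesh while the alias remains.
assert maps.Mesh is jax.sharding.Mesh
```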
@@ -68,9 +68,7 @@ from jax._src import util
 from jax._src import dispatch
 from jax._src import profiler
 from jax._src import stages
-from jax._src.sharding import (PmapSharding, NamedSharding, OpShardingSharding,
-                               SingleDeviceSharding, XLACompatibleSharding,
-                               _get_replicated_op_sharding)
+from jax._src import sharding as sharding_internal
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config
 from jax._src.config import flags
@@ -118,6 +116,7 @@ ShardingSpec = pmap_lib.ShardingSpec
 MeshAxisName = Any
 OpShardingType = Any

+PartitionSpec = sharding_internal.PartitionSpec

 def sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -542,7 +541,7 @@ class OutputType(enum.Enum):

 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding: XLACompatibleSharding,
+    sharding: sharding_internal.XLACompatibleSharding,
     indices: Optional[Tuple[Index, ...]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -844,7 +843,7 @@ def _sda_sharding(self):
   has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
   if has_unstacked:
     devices = np.array([d.device() for d in self.device_buffers])
-    return PmapSharding(devices, self.sharding_spec)
+    return sharding_internal.PmapSharding(devices, self.sharding_spec)
   raise NotImplementedError(
       'SDAs that are the output of pjit/xmap do not have the sharding attribute '
       'implemented. If you are trying to pass the SDA to pjit/xmap, please '
@@ -1541,9 +1540,9 @@ class UnloadedPmapExecutable:
   compiled: Any
   backend: xb.XlaBackend
   local_input_avals: Sequence[jax.core.AbstractValue]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   local_output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   unordered_effects: List[core.Effect]
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
@@ -1729,7 +1728,7 @@ class PmapExecutable(stages.XlaExecutable):


 def _get_pmap_sharding(devices, specs):
-  return [PmapSharding(devices, spec) for spec in specs]
+  return [sharding_internal.PmapSharding(devices, spec) for spec in specs]


 multi_host_supported_collectives: Set[core.Primitive] = set()
@@ -1915,11 +1914,11 @@ class ResultsHandler:


 def _get_sharding_specs(
-    shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
+    shardings: Sequence[sharding_internal.XLACompatibleSharding], avals: Sequence[ShapedArray]
 ) -> Sequence[ShardingSpec]:
-  if all(isinstance(s, PmapSharding) for s in shardings):
+  if all(isinstance(s, sharding_internal.PmapSharding) for s in shardings):
     return [s.sharding_spec for s in shardings]  # type: ignore
-  elif all(isinstance(s, NamedSharding) for s in shardings):
+  elif all(isinstance(s, sharding_internal.NamedSharding) for s in shardings):
     return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
         aval.ndim, _get_array_mapping(s.spec))
             for aval, s in safe_zip(avals, shardings)]
@@ -1930,7 +1929,7 @@ def _get_sharding_specs(

 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
-    local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+    local_shardings: Sequence[sharding_internal.XLACompatibleSharding]) -> ResultsHandler:
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
@@ -1942,7 +1941,7 @@ def local_avals_to_results_handler(

 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding],
+    shardings: Sequence[sharding_internal.XLACompatibleSharding],
     committed: bool,
     are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
   if config.jax_parallel_functions_output_gda or config.jax_array:
@@ -1954,10 +1953,10 @@ def global_avals_to_results_handler(
     return ResultsHandler(handlers, shardings, global_out_avals)
   else:
     # This path is taken when the outputs are SDAs.
-    assert all(isinstance(s, NamedSharding) for s in shardings)
+    assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
     local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
                        for aval, s in safe_zip(global_out_avals, shardings)]
-    local_shardings = [NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
+    local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec)  # type: ignore
                        for s in shardings]
     return local_avals_to_results_handler(local_out_avals, local_shardings)

@@ -2702,52 +2701,16 @@ TilingMethod = Union[TileVectorize, TileManual]


 def _check_if_any_auto(
-    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
-                              _UnspecifiedValue]]) -> bool:
+    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
+                              _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
   return False


-class _UnconstrainedPartitionSingleton:
-
-  def __str__(self):
-    return "UNCONSTRAINED"
-
-
-# Unconstrained sentinel value for PartitionSpec, representing a dimension for
-# which the user wants XLA to assign the best partitioning.
-# TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
-
-
-class PartitionSpec(tuple):
-  """Tuple of integer specifying how a value should be partitioned.
-
-  Each integer corresponds to how many ways a dimension is partitioned. We
-  create a separate class for this so JAX's pytree utilities can distinguish it
-  from a tuple that should be treated as a pytree.
-  """
-
-  # A sentinel value representing a dim is unconstrained.
-  UNCONSTRAINED = _UNCONSTRAINED_PARTITION
-
-  def __init__(self, *partitions):
-    pass
-
-  def __new__(cls, *partitions):
-    return tuple.__new__(PartitionSpec, partitions)
-
-  def __repr__(self):
-    return "PartitionSpec%s" % tuple.__repr__(self)
-
-  def __reduce__(self):
-    return (PartitionSpec, tuple(self))
-
-
 def _get_and_check_device_assignment(
-    shardings: Iterable[XLACompatibleSharding],
+    shardings: Iterable[sharding_internal.XLACompatibleSharding],
     devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices

@@ -2800,8 +2763,8 @@ def lower_sharding_computation(
     fun: lu.WrappedFun,
     api_name: str,
     fun_name: str,
-    in_shardings: Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]],
-    out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
+    in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]],
+    out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     in_is_global: Sequence[bool],
@@ -2847,7 +2810,7 @@ def lower_sharding_computation(
       jaxpr_sharding or
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore

-  in_shardings = tuple(OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)

   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -3019,8 +2982,8 @@ def lower_mesh_computation(
     api_name: str,
     fun_name: str,
     mesh: Mesh,
-    in_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource]],
-    out_shardings: Sequence[Union[NamedSharding, _AUTOAxisResource,
+    in_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource]],
+    out_shardings: Sequence[Union[sharding_internal.NamedSharding, _AUTOAxisResource,
                             _UnspecifiedValue]],
     donated_invars: Sequence[bool],
     spmd_lowering: bool,
@@ -3241,8 +3204,8 @@ class MeshComputation(stages.XlaLowering):

 def _get_input_metadata(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
            Sequence[ShapedArray]]:
   avals, shardings = _get_normalized_avals_and_shardings(
       global_in_avals, in_shardings, in_is_global)
@@ -3251,8 +3214,8 @@ def _get_input_metadata(

 def _get_normalized_avals_and_shardings(
     global_in_avals: Sequence[ShapedArray],
-    in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
-) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding], in_is_global: Sequence[bool]
+) -> Tuple[Sequence[ShapedArray], Sequence[sharding_internal.XLACompatibleSharding]]:
   avals = []
   shardings = []

@@ -3262,9 +3225,9 @@ def _get_normalized_avals_and_shardings(
       aval = gaval
       in_sharding = i
     else:
-      assert isinstance(i, NamedSharding)
+      assert isinstance(i, sharding_internal.NamedSharding)
       aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
-      in_sharding = NamedSharding(i.mesh.local_mesh, i.spec)
+      in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
     avals.append(aval)
     shardings.append(in_sharding)

@@ -3272,7 +3235,7 @@ def _get_normalized_avals_and_shardings(


 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
+    avals: Sequence[ShapedArray], shardings: Sequence[sharding_internal.XLACompatibleSharding]
 ) -> Sequence[Tuple[Optional[Index], ...]]:

   input_indices = []
@@ -3308,13 +3271,17 @@ def _get_op_sharding_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+    return ([sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_in_avals)],
+            [sharding_internal.SingleDeviceSharding(device_assignment[0])
+             for _ in range(num_out_avals)])

   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

-  in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
-  out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
+  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+                      for i in in_op_shardings]
+  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+                       for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
   # put the sharding on ROOT instead of the tuple.
@@ -3332,8 +3299,8 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
   from jax.experimental import pjit

   in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
-  return ([NamedSharding(mesh, i) for i in in_pspec],
-          [NamedSharding(mesh, o) for o in out_pspec])
+  return ([sharding_internal.NamedSharding(mesh, i) for i in in_pspec],
+          [sharding_internal.NamedSharding(mesh, o) for o in out_pspec])


 @dataclasses.dataclass
@@ -3342,9 +3309,9 @@ class UnloadedMeshExecutable:
   device_assignment: Sequence[xc.Device]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
-  input_shardings: Sequence[XLACompatibleSharding]
+  input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   output_avals: Sequence[ShapedArray]
-  output_shardings: Sequence[XLACompatibleSharding]
+  output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
   committed: bool
   are_out_shardings_from_xla: Sequence[bool]
   pmap_nreps: int
@@ -3396,8 +3363,8 @@ class UnloadedMeshExecutable:
       mesh: Optional[Mesh],
      global_in_avals: Sequence[ShapedArray],
      global_out_avals: Sequence[ShapedArray],
-      in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
-      out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
+      in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource]],
+      out_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, _AUTOAxisResource,
                               _UnspecifiedValue]],
       spmd_lowering: bool,
       tuple_args: bool,
@@ -3638,9 +3605,9 @@ class MeshExecutable(stages.XlaExecutable):

 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[XLACompatibleSharding],
+    in_shardings: Sequence[sharding_internal.XLACompatibleSharding],
     device_assignment: Sequence[xc.Device],
-) -> List[XLACompatibleSharding]:
+) -> List[sharding_internal.XLACompatibleSharding]:
   # For each jaxpr output, compute a Sharding by:
   # * if the output is a forwarded input, get the corresponding in_sharding;
   # * if the output is a constant Array, get its .sharding attribute;
@@ -3648,9 +3615,9 @@ def _out_shardings_for_trivial(
  #   a replicated sharding
  from jax._src import array

-  rep = OpShardingSharding(
-      device_assignment, _get_replicated_op_sharding())
-  shardings: Dict[core.Var, XLACompatibleSharding] = {}
+  rep = sharding_internal.OpShardingSharding(
+      device_assignment, sharding_internal._get_replicated_op_sharding())
+  shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
     if isinstance(constval, array.ArrayImpl):
       shardings[constvar] = constval.sharding
@@ -3744,7 +3711,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(

 @lru_cache()
 def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
-  return NamedSharding(mesh, pspec, parsed_pspec)
+  return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)


 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
@@ -19,8 +19,11 @@ from jax._src.sharding import (
     MeshPspecSharding as MeshPspecSharding,
     # New name of MeshPspecSharding to match PositionalSharding below.
     NamedSharding as NamedSharding,
+    PartitionSpec as PartitionSpec,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
     OpShardingSharding as OpShardingSharding,
     PositionalSharding as PositionalSharding,
 )
+
+from jax.interpreters.pxla import Mesh as Mesh
@@ -36,9 +36,8 @@ from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
 from jax import prng
-# TODO(skye): do we still wanna call this PartitionSpec?
+from jax.sharding import PartitionSpec as P
 from jax.experimental import maps
-from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap
 from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
@@ -328,7 +327,7 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
-             out_axis_resources=P(('x', 'y'),))
+             out_axis_resources=jax.sharding.PartitionSpec(('x', 'y'),))
     def f(x, y):
       return x @ y

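Outside the test harness, the same pjit pattern looks roughly like the sketch below. The mesh shape, array size, and doubling function are assumptions, and the sharded dimension must divide evenly across the devices:

```python
from functools import partial
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# A 1 x N mesh over the available devices (illustrative only).
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('x', 'y'))

@partial(pjit,
         in_axis_resources=PartitionSpec(('x', 'y'),),
         out_axis_resources=PartitionSpec(('x', 'y'),))
def f(x):
  return x * 2

with mesh:
  out = f(np.arange(8, dtype=np.float32))  # dim 0 sharded over both mesh axes
```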
@@ -2210,7 +2209,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
                              [devices[4], devices[6]],
                              [devices[7], devices[5]]])
     shape = (8, 2)
-    mesh = maps.Mesh(mesh_devices, ('x', 'y'))
+    mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
     s = NamedSharding(mesh, P('x', 'y'))
     inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
