Rename jax._src.sharding_utils to jax._src.op_shardings.

Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010

This commit is contained in:
  parent 492b9c1455
  commit 452f3c55e3
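For downstream call sites the change is a pure move: the helpers keep their names but now live in jax._src.op_shardings, and the compatibility re-exports near the end of this diff are updated to pull from the new module. A minimal sketch of the spelling change only (illustrative, not code from this commit):

    # Old spellings being replaced throughout this diff:
    #   from jax._src import sharding_utils as sutils
    #   sutils.is_op_sharding_replicated(...)
    #   pxla.are_op_shardings_equal(...)

    # New spellings (same helpers, new home):
    from jax._src import op_shardings
    op_shardings.is_op_sharding_replicated
    op_shardings.are_op_shardings_equal
    op_shardings.get_num_ways_dim_sharded
    op_shardings.op_sharding_to_indices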
@@ -26,8 +26,8 @@ from jax._src.api_util import shaped_abstractify  # technically not an api fn
from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax.interpreters import pxla
from jax._src import array
from jax._src import op_shardings
from jax._src.pjit import pjit_check_aval_sharding
from jax.experimental import pjit as pjit_lib
from jax.experimental import multihost_utils
@@ -579,7 +579,7 @@ def bench_are_op_shardings_equal(state):
  op2.tile_assignment_devices = list(range(12288))

  while state:
    pxla.are_op_shardings_equal(op1, op2)
    op_shardings.are_op_shardings_equal(op1, op2)


@google_benchmark.register
jax/BUILD (20 lines changed)
@@ -176,12 +176,12 @@ py_library_providing_imports_info(
        ":mesh",
        ":mlir",
        ":monitoring",
        ":op_shardings",
        ":partial_eval",
        ":path",
        ":pretty_printer",
        ":profiler",
        ":sharding",
        ":sharding_utils",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
@@ -341,8 +341,8 @@ pytype_strict_library(
        ":config",
        ":core",
        ":effects",
        ":op_shardings",
        ":partial_eval",
        ":sharding_utils",
        ":source_info_util",
        ":util",
        ":xla",
@@ -356,6 +356,14 @@ pytype_strict_library(
    srcs = ["_src/monitoring.py"],
)

pytype_strict_library(
    name = "op_shardings",
    srcs = ["_src/op_shardings.py"],
    deps = [
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

pytype_strict_library(
    name = "partial_eval",
    srcs = ["_src/interpreters/partial_eval.py"],
@@ -403,14 +411,6 @@ pytype_strict_library(
    ],
)

pytype_strict_library(
    name = "sharding_utils",
    srcs = ["_src/sharding_utils.py"],
    deps = [
        "//jax/_src/lib",
    ],
)

pytype_strict_library(
    name = "source_info_util",
    srcs = ["_src/source_info_util.py"],
@@ -34,10 +34,10 @@ from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src import sharding_utils as sutils
from jax._src.config import config
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
@@ -556,11 +556,11 @@ def sharded_aval(aval: core.AbstractValue,
  if not isinstance(aval, core.ShapedArray):
    raise NotImplementedError

  if (sutils.is_op_sharding_replicated(sharding) or
  if (op_shardings.is_op_sharding_replicated(sharding) or
      sharding.type == xc.OpSharding.Type.MANUAL):
    return aval

  partitions, _ = sutils.get_num_ways_dim_sharded(sharding)
  partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding)
  out = []
  for s, p in zip(aval.shape, partitions):
    quotient, remainder = divmod(s, p)
@@ -56,12 +56,12 @@ from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh
from jax._src import op_shardings
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import util
from jax._src import sharding_utils as sutils
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
@@ -208,48 +208,6 @@ def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType
  return proto


def _op_sharding_to_numpy_indices(
    op_sharding: xc.OpSharding, shape: Sequence[int],
    num_devices: int) -> np.ndarray:
  indices = np.empty(num_devices, dtype=np.object_)

  # num_devices is required as an argument when op_sharding is
  # REPLICATED. `jax.device_count()` cannot be used because you can create
  # an opsharding with less number of devices than `jax.device_count()`.
  if sutils.is_op_sharding_replicated(op_sharding):
    indices.fill((slice(None),) * len(shape))
    return indices

  assert num_devices == len(op_sharding.tile_assignment_devices)

  partitions, num_replicas = sutils.get_num_ways_dim_sharded(op_sharding)
  assert len(partitions) == len(shape), (len(partitions), len(shape))

  axis_indices: List[Sequence[Index]] = []
  for dim, n_shards in zip(shape, partitions):
    if n_shards == 1:
      axis_indices.append([slice(None)])
    elif n_shards > 1:
      shard_size, ragged = divmod(dim, n_shards)
      assert not ragged, (dim, n_shards)
      axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                           for i in range(n_shards)])
    else:
      raise AssertionError('Unrecognized number of shards. Please file a bug!')

  device_it = iter(op_sharding.tile_assignment_devices)
  for i, idxs in enumerate(it.product(*axis_indices)):
    for _ in range(num_replicas):
      indices[next(device_it)] = idxs
  return indices


def op_sharding_to_indices(op_sharding: xc.OpSharding, shape: Sequence[int],
                           num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
  indices = _op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
  return tuple(indices.flat)


def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
  """Returns NumPy-style indices corresponding to a sharding spec.
@@ -268,7 +226,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
  # Take the op sharding indices generation route for pjit/xmap cases.
  if not has_unstacked:
    op_sharding_proto = sharding_spec_sharding_proto(self)
    return _op_sharding_to_numpy_indices(
    return op_shardings.op_sharding_to_numpy_indices(
        op_sharding_proto, shape, math.prod(self.mesh_shape)
    ).reshape(self.mesh_shape)
@@ -2778,7 +2736,7 @@ def _get_input_indices(
    # represent index for each device in the global mesh. But here we want
    # indices for the local devices of the global mesh.
    proto = sharding._to_xla_op_sharding(aval.ndim)
    if sutils.is_op_sharding_replicated(proto):
    if op_shardings.is_op_sharding_replicated(proto):
      index = tuple(
          (slice(None),) * aval.ndim
          for _ in range(len(sharding.addressable_devices)))  # type: ignore
@@ -2986,7 +2944,7 @@ class UnloadedMeshExecutable:
        out_shardings.append(xla_s)
        are_out_shardings_from_xla.append(True)
      else:
        if not are_op_shardings_equal(
        if not op_shardings.are_op_shardings_equal(
            xla_s._to_xla_op_sharding(aval.ndim),  # type: ignore
            orig._to_xla_op_sharding(aval.ndim)):  # type: ignore
          raise AssertionError(
@@ -3274,7 +3232,8 @@ def check_gda_or_array_xla_sharding_match(
      # for AOT compiled call.
      if (not check_device_backend_on_shardings([xs]) and
          arg._committed and
          not are_op_shardings_equal(arg.sharding._to_xla_op_sharding(arg.ndim),
          not op_shardings.are_op_shardings_equal(
              arg.sharding._to_xla_op_sharding(arg.ndim),
              xs._to_xla_op_sharding(arg.ndim))):
        raise ValueError(
            f"Array sharding does not match the input sharding. "
@@ -3290,14 +3249,6 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
  return _get_array_mapping(parsed_pspec)


def are_op_shardings_equal(op1: xc.OpSharding, op2: xc.OpSharding) -> bool:
  if id(op1) == id(op2):
    return True
  if sutils.is_op_sharding_replicated(op1) and sutils.is_op_sharding_replicated(op2):
    return True
  return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)


_forbidden_primitives = {
  'xla_pmap': 'pmap',
  'sharded_call': 'sharded_jit',
@@ -28,6 +28,7 @@ from jax._src import dispatch
from jax._src import effects
from jax._src import mesh
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
@@ -1731,7 +1732,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
                                          args_flat):
  @lru_cache()
  def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
    if not pxla.are_op_shardings_equal(
    if not op_shardings.are_op_shardings_equal(
        in_sharding._to_xla_op_sharding(ndim),
        xmap_sharding._to_xla_op_sharding(ndim)):
      raise ValueError(
jax/_src/op_shardings.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sharding utilities"""

import itertools
from typing import List, Sequence, Tuple, Union

import numpy as np

from jax._src.lib import xla_client as xc


def get_num_ways_dim_sharded(
    op_sharding: xc.OpSharding) -> Tuple[Sequence[int], int]:
  partitions = op_sharding.tile_assignment_dimensions
  if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
    replicate_on_last_tile_dim = True
  else:
    replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
    if op_sharding.last_tile_dims:
      raise NotImplementedError(
          "Unhandled OpSharding type. Please open a bug report!")
  num_replicas = 1
  if replicate_on_last_tile_dim:
    num_replicas = partitions[-1]
    partitions = partitions[:-1]
  return partitions, num_replicas


def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
  if len(op.tile_assignment_devices) == 1:
    return True
  return xc.HloSharding.from_proto(op).is_replicated()  # type: ignore

def are_op_shardings_equal(op1: xc.OpSharding, op2: xc.OpSharding) -> bool:
  if id(op1) == id(op2):
    return True
  if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2):
    return True
  return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)

_Index = Union[int, slice, Tuple[Union[int, slice], ...]]


def op_sharding_to_numpy_indices(
    op_sharding: xc.OpSharding, shape: Sequence[int],
    num_devices: int) -> np.ndarray:
  indices = np.empty(num_devices, dtype=np.object_)

  # num_devices is required as an argument when op_sharding is
  # REPLICATED. `jax.device_count()` cannot be used because you can create
  # an opsharding with less number of devices than `jax.device_count()`.
  if is_op_sharding_replicated(op_sharding):
    indices.fill((slice(None),) * len(shape))
    return indices

  assert num_devices == len(op_sharding.tile_assignment_devices)

  partitions, num_replicas = get_num_ways_dim_sharded(op_sharding)
  assert len(partitions) == len(shape), (len(partitions), len(shape))

  axis_indices: List[Sequence[_Index]] = []
  for dim, n_shards in zip(shape, partitions):
    if n_shards == 1:
      axis_indices.append([slice(None)])
    elif n_shards > 1:
      shard_size, ragged = divmod(dim, n_shards)
      assert not ragged, (dim, n_shards)
      axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                           for i in range(n_shards)])
    else:
      raise AssertionError('Unrecognized number of shards. Please file a bug!')

  device_it = iter(op_sharding.tile_assignment_devices)
  for i, idxs in enumerate(itertools.product(*axis_indices)):
    for _ in range(num_replicas):
      indices[next(device_it)] = idxs
  return indices


def op_sharding_to_indices(op_sharding: xc.OpSharding, shape: Sequence[int],
                           num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
  indices = op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
  return tuple(indices.flat)
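As an illustration only (not part of the commit), the relocated helpers can be exercised directly on hand-built OpSharding protos, much as the UtilTest cases near the end of this diff do. The shapes and device lists below are made up for the example:

    from jax._src import op_shardings
    from jax._src.lib import xla_client as xc

    # A 4x2 tiled sharding over 8 devices.
    tiled = xc.OpSharding()
    tiled.type = xc.OpSharding.Type.OTHER
    tiled.tile_assignment_dimensions = [4, 2]
    tiled.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]

    # A fully replicated sharding.
    replicated = xc.OpSharding()
    replicated.type = xc.OpSharding.Type.REPLICATED

    print(op_shardings.is_op_sharding_replicated(replicated))      # True
    print(op_shardings.are_op_shardings_equal(tiled, replicated))  # False
    # Per-dimension partition counts [4, 2] and a single replica group.
    print(op_shardings.get_num_ways_dim_sharded(tiled))

    # One (row slice, column slice) tuple per device for an (8, 4) array.
    idxs = op_shardings.op_sharding_to_indices(tiled, (8, 4), num_devices=8)
    print(len(idxs))  # 8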
@@ -29,8 +29,8 @@ from jax._src import stages
from jax._src import dispatch
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import sharding_utils as sutils
from jax._src import traceback_util
from jax._src import util
from jax._src import xla_bridge as xb
@@ -962,7 +962,7 @@ def pjit_check_aval_sharding(
      # XLACompatibleSharding.
      op_sharding = s._to_xla_op_sharding(len(shape))
      assert op_sharding is not None
      num_ways_dim_sharded, _ = sutils.get_num_ways_dim_sharded(
      num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(
          cast(xc.OpSharding, op_sharding))
      for i, size in enumerate(num_ways_dim_sharded):
        if not allow_uneven_sharding and shape[i] % size != 0:
@@ -1201,8 +1201,10 @@ def _resolve_in_shardings(
          raise NotImplementedError('Having uncommitted Array sharded on '
                                    'multiple devices is not supported.')
        else:
          if isinstance(arg, np.ndarray) and not sutils.is_op_sharding_replicated(
              pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1:  # type: ignore
          if (isinstance(arg, np.ndarray) and
              not op_shardings.is_op_sharding_replicated(
                  pjit_in_s._to_xla_op_sharding(arg.ndim))  # type: ignore
              and xb.process_count() > 1):
            raise ValueError(
                'Passing non-trivial shardings for numpy '
                'inputs is not allowed. To fix this error, either specify a '
@@ -1219,7 +1221,7 @@ def _resolve_in_shardings(
      if not _is_unspecified(arg_s):
        if (committed and
            not isinstance(arg_s, PmapSharding) and
            not pxla.are_op_shardings_equal(
            not op_shardings.are_op_shardings_equal(
                pjit_in_s._to_xla_op_sharding(arg.ndim),  # type: ignore
                arg_s._to_xla_op_sharding(arg.ndim))):
          op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
@@ -1314,11 +1316,15 @@ class SameDeviceAssignmentTuple:
  def __eq__(self, other):
    if not isinstance(other, SameDeviceAssignmentTuple):
      return False
    return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
    return (
        all(
            op_shardings.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
            if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
            else s == o
            for s, o in zip(self.shardings, other.shardings)) and
            self.device_assignment == other.device_assignment)
            for s, o in zip(self.shardings, other.shardings)
        )
        and self.device_assignment == other.device_assignment
    )


def _pjit_lower(
@@ -22,8 +22,8 @@ from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding
from jax._src import sharding_utils as sutils
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
@@ -74,9 +74,9 @@ class XLACompatibleSharding(sharding.Sharding):
  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
    if sutils.is_op_sharding_replicated(op_sharding):
    if op_shardings.is_op_sharding_replicated(op_sharding):
      return global_shape
    partitions, _ = sutils.get_num_ways_dim_sharded(op_sharding)
    partitions, _ = op_shardings.get_num_ways_dim_sharded(op_sharding)
    assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
    out = []
    for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
@@ -94,9 +94,12 @@ class XLACompatibleSharding(sharding.Sharding):
  def is_equivalent_to(self: XLACompatibleSharding,  # type: ignore
                       other: XLACompatibleSharding, ndim: int) -> bool:
    try:
      return (pxla.are_op_shardings_equal(self._to_xla_op_sharding(ndim),
                                          other._to_xla_op_sharding(ndim)) and
              self._device_assignment == other._device_assignment)
      return (
          op_shardings.are_op_shardings_equal(
              self._to_xla_op_sharding(ndim), other._to_xla_op_sharding(ndim)
          )
          and self._device_assignment == other._device_assignment
      )
    # NotImplementedError is raised by PmapSharding because it can't lower
    # to OpSharding. So if `other` is a PmapSharding, default to a strict
    # equality check.
@@ -611,8 +614,12 @@ class GSPMDSharding(XLACompatibleSharding):
      return False
    if id(self) == id(other):
      return True
    return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
            self._devices == other._devices)
    return (
        op_shardings.are_op_shardings_equal(
            self._op_sharding, other._op_sharding
        )
        and self._devices == other._devices
    )

  def __hash__(self):
    if not hasattr(self, '_hash'):
@@ -623,7 +630,8 @@ class GSPMDSharding(XLACompatibleSharding):
    return f'GSPMDSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'

  def is_compatible_aval(self, aval_shape: Shape):
    num_ways_dim_sharded, _ = sutils.get_num_ways_dim_sharded(self._op_sharding)
    num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(
        self._op_sharding)
    if len(aval_shape) < len(num_ways_dim_sharded):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
@@ -637,8 +645,8 @@ class GSPMDSharding(XLACompatibleSharding):
  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    self.shard_shape(global_shape)  # raises a good error message
    indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                          len(self._devices))
    indices = op_shardings.op_sharding_to_indices(
        self._op_sharding, global_shape, len(self._devices))
    return dict(safe_zip(self._devices, indices))

  @property
@@ -1,40 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sharding utilities"""

from typing import Sequence, Tuple
from jax._src.lib import xla_client as xc


def get_num_ways_dim_sharded(
    op_sharding: xc.OpSharding) -> Tuple[Sequence[int], int]:
  partitions = op_sharding.tile_assignment_dimensions
  if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
    replicate_on_last_tile_dim = True
  else:
    replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
    if op_sharding.last_tile_dims:
      raise NotImplementedError(
          "Unhandled OpSharding type. Please open a bug report!")
  num_replicas = 1
  if replicate_on_last_tile_dim:
    num_replicas = partitions[-1]
    partitions = partitions[:-1]
  return partitions, num_replicas


def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
  if len(op.tile_assignment_devices) == 1:
    return True
  return xc.HloSharding.from_proto(op).is_replicated()  # type: ignore
@@ -48,9 +48,9 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_utils as sutils
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
@@ -3143,7 +3143,8 @@ def _shard_value(val: TfVal,
  sharding_proto: xla_client.OpSharding = cast(
      xla_client.OpSharding, sd._to_xla_op_sharding(aval.ndim))

  if skip_replicated_sharding and sutils.is_op_sharding_replicated(sharding_proto):
  if (skip_replicated_sharding and
      op_shardings.is_op_sharding_replicated(sharding_proto)):
    return val

  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
@@ -55,7 +55,6 @@ from jax._src.interpreters.pxla import (
    _get_and_check_device_assignment as _get_and_check_device_assignment,
    _is_unspecified as _is_unspecified,
    _pmap_sharding_spec as _pmap_sharding_spec,
    are_op_shardings_equal as are_op_shardings_equal,
    array_mapping_to_axis_resources as array_mapping_to_axis_resources,
    array_types as array_types,
    custom_resource_typing_rules as custom_resource_typing_rules,
@@ -78,7 +77,6 @@ from jax._src.interpreters.pxla import (
    mesh_sharding_specs as mesh_sharding_specs,
    multi_host_supported_collectives as multi_host_supported_collectives,
    new_mesh_sharding_specs as new_mesh_sharding_specs,
    op_sharding_to_indices as op_sharding_to_indices,
    parallel_callable as parallel_callable,
    partitioned_sharding_spec as partitioned_sharding_spec,
    reconcile_num_partitions as reconcile_num_partitions,
@@ -110,8 +108,10 @@ from jax._src.mesh import (
    thread_resources as thread_resources,
)

from jax._src.sharding_utils import (
    is_op_sharding_replicated as is_op_sharding_replicated
from jax._src.op_shardings import (
    are_op_shardings_equal as are_op_shardings_equal,
    is_op_sharding_replicated as is_op_sharding_replicated,
    op_sharding_to_indices as op_sharding_to_indices,
)

# Deprecations
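A practical consequence of the re-export hunk above (which appears to edit the jax.interpreters.pxla compatibility shim, judging by its import-and-alias style): names that downstream code pulled from that shim keep resolving, now backed by the new module. A rough sketch, not taken from the diff:

    from jax.interpreters import pxla

    # Still importable after this commit; re-exported from jax._src.op_shardings.
    pxla.are_op_shardings_equal
    pxla.is_op_sharding_replicated
    pxla.op_sharding_to_indices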
@@ -24,11 +24,11 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
@@ -456,7 +456,7 @@ class JaxArrayTest(jtu.JaxTestCase):
      self.assertArraysEqual(i, j)
    self.assertLen(i.sharding.device_set, 8)
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            arr.sharding._to_xla_op_sharding(arr.ndim),
            i.sharding._to_xla_op_sharding(i.ndim)))
@@ -514,7 +514,7 @@ class JaxArrayTest(jtu.JaxTestCase):
    self.assertArraysEqual(s, np.array([[4], [6]]))
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            arr.sharding._to_xla_op_sharding(arr.ndim),
            s.sharding._to_xla_op_sharding(s.ndim)))
@@ -523,7 +523,7 @@ class JaxArrayTest(jtu.JaxTestCase):
    self.assertArraysEqual(p, input_data[:2])
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            arr.sharding._to_xla_op_sharding(arr.ndim),
            s.sharding._to_xla_op_sharding(s.ndim)))
@@ -603,11 +603,11 @@ class JaxArrayTest(jtu.JaxTestCase):
    self.assertEqual(input_shardings[1], {})

    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            output_shardings._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
@@ -626,11 +626,11 @@ class JaxArrayTest(jtu.JaxTestCase):
    c = pjit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
    self.assertTrue(
        pxla.are_op_shardings_equal(
        op_shardings.are_op_shardings_equal(
            output_shardings._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
@@ -837,7 +837,7 @@ class ShardingTest(jtu.JaxTestCase):

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

  def test_devices_sharding_respects_init_mesh_shape(self):
    value_shape = (8, 4)
@@ -852,7 +852,7 @@ class ShardingTest(jtu.JaxTestCase):

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

  def test_pmap_sharding_repr(self):
    if jax.device_count() < 2:
@@ -43,7 +43,7 @@ from jax.experimental import multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding
from jax._src import sharding_utils as sutils
from jax._src import op_shardings
from jax._src.sharding_impls import NamedSharding, GSPMDSharding
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, AUTO)
@@ -665,7 +665,7 @@ class PJitTest(jtu.BufferDonationTestCase):
    self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
    self.assertListEqual(op.tile_assignment_dimensions, [1, 2])
    self.assertListEqual(op.tile_assignment_devices, [0, 1])
    self.assertFalse(sutils.is_op_sharding_replicated(op))
    self.assertFalse(op_shardings.is_op_sharding_replicated(op))

  @jtu.with_mesh([('x', 2)])
  def testVMapShardingConstraintWithSpmdAxis(self):
@@ -685,7 +685,7 @@ class PJitTest(jtu.BufferDonationTestCase):
    self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
    self.assertListEqual(op.tile_assignment_dimensions, [2, 1])
    self.assertListEqual(op.tile_assignment_devices, [0, 1])
    self.assertFalse(sutils.is_op_sharding_replicated(op))
    self.assertFalse(op_shardings.is_op_sharding_replicated(op))

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testShardingInXMap(self):
@@ -702,7 +702,7 @@ class PJitTest(jtu.BufferDonationTestCase):
      self.assertLen(in_shardings, 1)
      self.assertListEqual(in_shardings[0]._op_sharding.tile_assignment_dimensions,
                           [1, 1, 2])
      self.assertFalse(sutils.is_op_sharding_replicated(in_shardings[0]._op_sharding))
      self.assertFalse(op_shardings.is_op_sharding_replicated(in_shardings[0]._op_sharding))

      return rule(*args, **kwargs)
    try:
@@ -1956,7 +1956,7 @@ class ArrayPjitTest(jtu.JaxTestCase):

    out = sharded_zeros((4096, 3072), P('x', 'y'))
    out_s = NamedSharding(mesh, P('x', 'y'))
    self.assertTrue(pxla.are_op_shardings_equal(
    self.assertTrue(op_shardings.are_op_shardings_equal(
        out.sharding._to_xla_op_sharding(out.ndim),
        out_s._to_xla_op_sharding(out.ndim)))
@@ -1970,7 +1970,7 @@ class ArrayPjitTest(jtu.JaxTestCase):

    out = sharded_zeros((4096, 3072), P('x', 'y'))
    out_s = NamedSharding(mesh, P('x', 'y'))
    self.assertTrue(pxla.are_op_shardings_equal(
    self.assertTrue(op_shardings.are_op_shardings_equal(
        out.sharding._to_xla_op_sharding(out.ndim),
        out_s._to_xla_op_sharding(out.ndim)))
@@ -2563,9 +2563,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
    self.assertArraysEqual(out2, inp2 * 2)
    self.assertLen(out1.devices(), 4)
    self.assertLen(out2.devices(), 4)
    self.assertTrue(sutils.is_op_sharding_replicated(
    self.assertTrue(op_shardings.is_op_sharding_replicated(
        out1.sharding._to_xla_op_sharding(pmap_out.ndim)))
    self.assertTrue(sutils.is_op_sharding_replicated(
    self.assertTrue(op_shardings.is_op_sharding_replicated(
        out2.sharding._to_xla_op_sharding(inp2.ndim)))

  def test_pmap_sharding_input_pjit_in_axis_resources(self):
@@ -2771,7 +2771,7 @@ class ArrayPjitTest(jtu.JaxTestCase):

    with mesh:
      out = jax.vmap(jax.jit(f), spmd_axis_name='mdl')(x)
    ns, _ = sutils.get_num_ways_dim_sharded(
    ns, _ = op_shardings.get_num_ways_dim_sharded(
        out.sharding._to_xla_op_sharding(out.ndim))
    self.assertListEqual(ns, [2, 2, 1, 1])
@@ -2781,7 +2781,7 @@ class ArrayPjitTest(jtu.JaxTestCase):

    with mesh:
      out2 = jax.vmap(apply_with_scan, spmd_axis_name='mdl')(x)
    ns2, _ = sutils.get_num_ways_dim_sharded(
    ns2, _ = op_shardings.get_num_ways_dim_sharded(
        out2.sharding._to_xla_op_sharding(out2.ndim))
    self.assertListEqual(ns2, [2, 2, 1, 1])
@@ -3355,9 +3355,9 @@ class UtilTest(jtu.JaxTestCase):
    op3.tile_assignment_dimensions = [4, 2]
    op3.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]

    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertFalse(pxla.are_op_shardings_equal(op1, op3))
    self.assertFalse(pxla.are_op_shardings_equal(op2, op3))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertFalse(op_shardings.are_op_shardings_equal(op1, op3))
    self.assertFalse(op_shardings.are_op_shardings_equal(op2, op3))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
@@ -3380,7 +3380,7 @@ class UtilTest(jtu.JaxTestCase):
    op2.tile_assignment_devices = [0, 1, 2, 3]
    op2.last_tile_dims = [xc.OpSharding.Type.REPLICATED]

    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
@@ -3407,7 +3407,7 @@ class UtilTest(jtu.JaxTestCase):
    op2.type = xc.OpSharding.Type.TUPLE
    op2.tuple_shardings = [top2, top1]

    self.assertFalse(pxla.are_op_shardings_equal(op1, op2))
    self.assertFalse(op_shardings.are_op_shardings_equal(op1, op2))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
@@ -3465,13 +3465,13 @@ class UtilTest(jtu.JaxTestCase):
    op4.tile_assignment_dimensions = [1]
    op4.tile_assignment_devices = [0]

    self.assertTrue(sutils.is_op_sharding_replicated(op1))
    self.assertTrue(sutils.is_op_sharding_replicated(op2))
    self.assertTrue(sutils.is_op_sharding_replicated(op3))
    self.assertTrue(sutils.is_op_sharding_replicated(op4))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertTrue(pxla.are_op_shardings_equal(op2, op3))
    self.assertTrue(pxla.are_op_shardings_equal(op3, op4))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op1))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op2))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op3))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op4))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op2, op3))
    self.assertTrue(op_shardings.are_op_shardings_equal(op3, op4))

  def test_op_sharding_manual_replicated(self):
    op1 = xc.OpSharding()
@@ -3489,10 +3489,10 @@ class UtilTest(jtu.JaxTestCase):
    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.REPLICATED

    self.assertTrue(sutils.is_op_sharding_replicated(op1))
    self.assertTrue(sutils.is_op_sharding_replicated(op2))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op3))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op1))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))

  def test_op_sharding_cache_on_mesh_pspec_sharding(self):
    ndim = 2