# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Sequence
import enum
import functools
import math
import operator as op
from typing import Any, TYPE_CHECKING, cast

from jax._src import api
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.mesh import set_concrete_mesh
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
    PmapSharding, SingleDeviceSharding,
    device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape)  # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
import numpy as np


Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
PRNGKeyArray = Any  # TODO(jakevdp): fix cycles and import this.


def _get_device(a: ArrayImpl) -> Device:
  devices = a.sharding._internal_device_list  # pytype: disable=attribute-error
  if len(devices) != 1:
    raise ValueError(
        "When making an array from single-device arrays the input arrays must "
        f"have one shard each. An argument array had {len(devices)} shard(s).")
  return devices[0]


class Shard:
  """A single data shard of an Array.

  Attributes:
    device : Which device this shard resides on.
    index : The index into the global array of this shard.
    replica_id : Integer id indicating which replica of the global array this
      shard is part of. Always 0 for fully sharded data
      (i.e. when there’s only 1 replica).
    data : The data of this shard. None if ``device`` is non-local.
  """

  def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
               data: None | ArrayImpl | PRNGKeyArray = None):
    self._device = device
    self._sharding = sharding
    self._global_shape = global_shape
    self._data = data

  def __repr__(self):
    try:
      return (f'Shard(device={self.device!r}, index={self.index}, '
              f'replica_id={self.replica_id}, data={self.data})')
    except ValueError:
      return f'Shard(device={self.device!r}, data={self.data})'

  @functools.cached_property
  def index(self) -> Index:
    try:
      device_indices_map_fn = self._sharding.devices_indices_map
    except AttributeError:
      raise ValueError('Cannot calculate indices from sharding: '
                       f'{self._sharding}. Please create a device to index '
                       'mapping for your sharding.') from None
    index = device_indices_map_fn(self._global_shape)[self.device]
    assert index is not None
    return index

  @functools.cached_property
  def replica_id(self) -> int:
    return device_replica_id_map(self._sharding, self._global_shape)[self.device]

  @property
  def device(self):
    return self._device

  @property
  def data(self):
    return self._data
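
# Illustrative sketch (not executed as part of this module): how `Shard`
# objects typically surface to users, via `jax.Array.addressable_shards`.
# Assumes at least one addressable device; run in a separate script or REPL.
#
#   import jax
#   import jax.numpy as jnp
#
#   x = jax.device_put(jnp.arange(8.0))
#   for shard in x.addressable_shards:
#     # Each shard records the device it lives on, its index into the global
#     # array, its replica id, and the per-device data buffer.
#     print(shard.device, shard.index, shard.replica_id, shard.data.shape)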


def _reconstruct_array(fun, args, arr_state, aval_state):
  """Method to reconstruct a device array from a serialized state."""
  np_value = fun(*args)
  np_value.__setstate__(arr_state)
  jnp_value = api.device_put(np_value)
  # TODO(slebedev): Remove this branch after December 10th 2024.
  if "named_shape" in aval_state:
    deprecations.warn(
        "jax-aval-named-shape",
        "Pickled array contains an aval with a named_shape attribute. This is"
        " deprecated and the code path supporting such avals will be removed."
        " Please re-pickle the array.",
        stacklevel=2,
    )
    del aval_state["named_shape"]
  jnp_value.aval = jnp_value.aval.update(**aval_state)
  return jnp_value


@cache(max_size=4096, trace_context_in_key=False)
def _cached_index_calc(s, shape):
  map_ = s.addressable_devices_indices_map(shape)
  seen_h_indices = set()
  l = []
  for array_index, index in enumerate(map_.values()):
    h_index = hashed_index(index)
    if h_index not in seen_h_indices:
      seen_h_indices.add(h_index)
      l.append((array_index, index))
  return l


@cache(max_size=4096, trace_context_in_key=False)
def _process_has_full_value_in_mcjax(s, shape):
  # Return False for single host as a fast path.
  if xla_bridge.process_count() == 1:
    return False

  num_unique_indices = len(
      {hashed_index(v) for v in s.devices_indices_map(shape).values()})
  num_addressable_unique_indices = len(
      {hashed_index(v) for v in s.addressable_devices_indices_map(shape).values()})
  return num_unique_indices == num_addressable_unique_indices


def _validate_shape_and_dtype_for_per_device_arrays(
    arrays: Sequence[ArrayImpl | np.ndarray],
    sharding: Sharding,
    aval: core.ShapedArray,
    expected_shape: Shape,
):
  """Validates that per-device arrays are valid and consistent."""
  expected_dtype = aval.dtype
  for db in arrays:
    if db.dtype != expected_dtype:
      raise ValueError(
          "Input buffers to `Array` must have matching dtypes. "
          f"Got {db.dtype}, expected {expected_dtype} for buffer: {db}"
      )
    if db.shape != expected_shape:
      raise ValueError(
          f"Expected shard shape {expected_shape} doesn't match the single "
          f"device array shape {db.shape}. Shape of Array is "
          f"{aval.str_short()} with sharding {sharding}"
      )


class ArrayImpl(basearray.Array):
  # TODO(yashkatariya): Add __slots__ here.

  aval: core.ShapedArray
  _sharding: Sharding
  _arrays: list[ArrayImpl]
  _committed: bool
  _skip_checks: bool
  _npy_value: np.ndarray | None

  @use_cpp_method()
  def __init__(self, aval: core.ShapedArray, sharding: Sharding,
               arrays: Sequence[ArrayImpl],
               committed: bool, _skip_checks: bool = False):
    # NOTE: the actual implementation of the constructor is moved to C++.

    self.aval = aval
    self._sharding = sharding
    self._committed = committed
    self._npy_value = None
    arrays = [a._arrays[0] for a in arrays]

    # Don't rearrange if skip_checks is enabled because this assumes that the
    # input buffers are already arranged properly. This usually happens when
    # Arrays are created as output of a JAX transformation
    # (like pjit, etc).
    if not _skip_checks or config.enable_checks.value:
      arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
    self._arrays = arrays

  def _check_and_rearrange(self, arrays, sharding, aval):
    device_id_to_buffer = {_get_device(db).id: db for db in arrays}

    addressable_dev = sharding.addressable_devices
    if len(arrays) != len(addressable_dev):
      raise ValueError(
          f"Expected {len(addressable_dev)} per-device arrays "
          "(this is how many devices are addressable by the sharding), but "
          f"got {len(arrays)}")

    array_device_ids = set(device_id_to_buffer.keys())
    addressable_device_ids = {d.id for d in addressable_dev}
    if len(array_device_ids) != len(arrays):
      buffer_device_ids = [_get_device(db).id for db in arrays]
      raise ValueError(
          "When making an array from single-device arrays, the input arrays"
          " must be from distinct devices, but got device IDs"
          f" {buffer_device_ids}")

    # Calculate a symmetric difference because the device ids between sharding
    # and _arrays should match.
    diff = array_device_ids ^ addressable_device_ids
    if diff:
      dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
      dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
      err_msg = (
          "Addressable devices and per-device arrays devices do not match.")
      if dev_in_sharding_not_in_arrays:
        err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
                    "that are not present in per-device arrays.")
      if dev_in_arrays_not_in_sharding:
        err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
                    "that are not present in the sharding.")
      raise ValueError(err_msg)

    _validate_shape_and_dtype_for_per_device_arrays(
        arrays,
        sharding=sharding,
        aval=aval,
        expected_shape=sharding.shard_shape(aval.shape),
    )

    # Rearrange arrays based on the device assignment.
    addressable_da = sharding._addressable_device_assignment
    return [device_id_to_buffer[device.id] for device in addressable_da]

  @property
  def shape(self) -> Shape:
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    return math.prod(self.shape)

  @property
  def sharding(self):
    return self._sharding

  @property
  def device(self):
    self._check_if_deleted()
    if isinstance(self.sharding, SingleDeviceSharding):
      return list(self.sharding.device_set)[0]
    return self.sharding

  @property
  def weak_type(self):
    return self.aval.weak_type

  @property
  def committed(self) -> bool:
    return self._committed

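  # Illustrative sketch (not executed here): the basic metadata accessors
  # defined above. Note that `.device` returns a concrete `Device` only when
  # the array lives on a single device; for multi-device arrays it returns the
  # `Sharding`. Assumes a default backend with at least one device.
  #
  #   import jax
  #   import jax.numpy as jnp
  #
  #   x = jnp.ones((4, 2), dtype=jnp.float32)
  #   print(x.shape, x.dtype, x.ndim, x.size)  # (4, 2) float32 2 8
  #   print(x.sharding)                        # a SingleDeviceSharding here
  #   print(x.device)                          # the single device it lives on
  #   print(x.committed)                       # typically False when no
  #                                            # explicit placement was asked
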
  def __str__(self):
    return str(self._value)

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  def __bool__(self):
    core.check_bool_conversion(self)
    return bool(self._value)

  def __float__(self):
    core.check_scalar_conversion(self)
    return self._value.__float__()

  def __int__(self):
    core.check_scalar_conversion(self)
    return self._value.__int__()

  def __complex__(self):
    core.check_scalar_conversion(self)
    return self._value.__complex__()

  def __hex__(self):
    core.check_integer_conversion(self)
    return hex(self._value)

  def __oct__(self):
    core.check_integer_conversion(self)
    return oct(self._value)

  def __index__(self):
    core.check_integer_conversion(self)
    return op.index(self._value)

  def tobytes(self, order="C"):
    return self._value.tobytes(order)

  def tolist(self):
    return self._value.tolist()

  def __format__(self, format_spec):
    # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    if self.ndim == 0:
      return format(self._value[()], format_spec)
    else:
      return format(self._value, format_spec)

  def __getitem__(self, idx):
    from jax._src.lax import lax
    from jax._src.numpy import indexing
    self._check_if_deleted()

    if isinstance(self.sharding, PmapSharding):
      if config.pmap_no_rank_reduction.value:
        cidx = idx if isinstance(idx, tuple) else (idx,)

        padded_cidx = tuple(
            slice(i, i + 1, None) if isinstance(i, int) else i for i in cidx
        ) + (slice(None),) * (len(self.shape) - len(cidx))
      else:
        if not isinstance(idx, tuple):
          padded_cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
        else:
          padded_cidx = idx + (slice(None),) * (len(self.shape) - len(idx))

      indices = tuple(self.sharding.devices_indices_map(self.shape).values())
      try:
        arr_idx = indices.index(padded_cidx)
      except ValueError:
        arr_idx = None
      if arr_idx is not None:
        out = self._arrays[arr_idx]
        sharding = SingleDeviceSharding(_get_device(out))

        if config.pmap_no_rank_reduction.value:
          # If cidx was the index of a single shard, then it corresponds to one
          # shard of the chunked dimension.
          dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int))
          # Squeeze on committed arrays to avoid data movement to shard 0.
          out = lax.squeeze(out, dimensions=dims)

        return ArrayImpl(
            out.aval, sharding, [out], committed=False, _skip_checks=True)

    return indexing.rewriting_take(self, idx)
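
  # Illustrative sketch (not executed here): the `PmapSharding` fast path in
  # `__getitem__` above. Indexing a pmap output with an integer along the
  # mapped axis can be served directly from the matching per-device shard,
  # avoiding a cross-device gather. Assumes at least one local device.
  #
  #   import jax
  #   import jax.numpy as jnp
  #
  #   n = jax.local_device_count()
  #   y = jax.pmap(lambda v: v * 2)(jnp.arange(n))
  #   first = y[0]  # resolved from the shard on the first device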

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error
    else:
      assert self.is_fully_replicated or self.is_fully_addressable
      if dispatch.is_single_device_sharding(self.sharding) or self.is_fully_replicated:
        return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
      elif isinstance(self.sharding, PmapSharding):
        return (self[i] for i in range(self.shape[0]))
      else:
        # TODO(yashkatariya): Don't bounce to host and use `_chunk_iter` path
        # here after uneven partitioning support is added.
        return (api.device_put(self._value[i]) for i in range(self.shape[0]))

  @property
  def is_fully_replicated(self) -> bool:
    return self.sharding.is_fully_replicated

  def __repr__(self):
    prefix = 'Array('
    if self.aval is not None and self.aval.weak_type:
      dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
    else:
      dtype_str = f'dtype={self.dtype.name})'

    if self.is_fully_addressable or self.is_fully_replicated:
      line_width = np.get_printoptions()["linewidth"]
      if self.size == 0:
        s = f"[], shape={self.shape}"
      else:
        s = np.array2string(self._value, prefix=prefix, suffix=',',
                            separator=', ', max_line_width=line_width)
      last_line_len = len(s) - s.rfind('\n') + 1
      sep = ' '
      if last_line_len + len(dtype_str) + 1 > line_width:
        sep = ' ' * len(prefix)
      return f"{prefix}{s},{sep}{dtype_str}"
    else:
      return f"{prefix}{self.shape}, {dtype_str}"

  @property
  def is_fully_addressable(self) -> bool:
    """Is this Array fully addressable?

    A jax.Array is fully addressable if the current process can address all of
    the devices named in the :class:`Sharding`. ``is_fully_addressable`` is
    equivalent to "is_local" in multi-process JAX.

    Note that fully replicated is not the same as fully addressable: a
    jax.Array which is fully replicated can span multiple hosts and therefore
    not be fully addressable.
    """
    return self.sharding.is_fully_addressable

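  # Illustrative sketch (not executed here): in a single-process setting every
  # array is fully addressable, while `is_fully_replicated` depends on the
  # sharding. Assumes a single-process run with the default backend.
  #
  #   import jax.numpy as jnp
  #
  #   x = jnp.zeros((8,))
  #   assert x.is_fully_addressable  # single process addresses all devices
  #   assert x.is_fully_replicated   # one device holds the whole value
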
  def __array__(self, dtype=None, context=None, copy=None):
    # copy argument is supported by np.asarray starting in numpy 2.0
    kwds = {} if copy is None else {'copy': copy}
    return np.asarray(self._value, dtype=dtype, **kwds)

  def __dlpack__(self, *, stream: int | Any | None = None,
                 max_version: tuple[int, int] | None = None,
                 dl_device: tuple[DLDeviceType, int] | None = None,
                 copy: bool | None = None):
    from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top

    device_set = self.sharding.device_set
    if len(device_set) > 1:
      raise BufferError(
          "to_dlpack can only pack a dlpack tensor from an array on a singular "
          f"device, but an array with a Sharding over {len(device_set)} devices "
          "was provided."
      )
    device, = device_set
    return to_dlpack(self, stream=stream,
                     max_version=max_version,
                     src_device=device,
                     dl_device=dl_device,
                     copy=copy)

  def __dlpack_device__(self) -> tuple[enum.Enum, int]:
    if len(self._arrays) != 1:
      raise BufferError("__dlpack__ only supported for unsharded arrays.")

    from jax._src.dlpack import DLDeviceType  # pylint: disable=g-import-not-at-top

    if self.platform() == "cpu":
      return DLDeviceType.kDLCPU, 0

    elif self.platform() == "gpu":
      platform_version = _get_device(self).client.platform_version
      if "cuda" in platform_version:
        dl_device_type = DLDeviceType.kDLCUDA
      elif "rocm" in platform_version:
        dl_device_type = DLDeviceType.kDLROCM
      else:
        raise BufferError("Unknown GPU platform for __dlpack__: "
                          f"{platform_version}")

      local_hardware_id = _get_device(self).local_hardware_id
      if local_hardware_id is None:
        raise BufferError("Couldn't get local_hardware_id for __dlpack__")

      return dl_device_type, local_hardware_id

    else:
      raise BufferError(
          "__dlpack__ device only supported for CPU and GPU, got platform: "
          f"{self.platform()}"
      )

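  # Illustrative sketch (not executed here): the DLPack protocol implemented
  # above lets other libraries ingest single-device jax.Arrays, without a copy
  # where possible. Assumes a CPU device is available and numpy >= 1.22.
  #
  #   import jax
  #   import jax.numpy as jnp
  #   import numpy as np
  #
  #   x = jax.device_put(jnp.arange(4.0), jax.devices("cpu")[0])
  #   np_view = np.from_dlpack(x)  # consumes x.__dlpack__ / x.__dlpack_device__
  #   print(np_view)
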
  def __reduce__(self):
    fun, args, arr_state = self._value.__reduce__()
    aval_state = {'weak_type': self.aval.weak_type}
    return (_reconstruct_array, (fun, args, arr_state, aval_state))

  @use_cpp_method()
  def unsafe_buffer_pointer(self):
    if len(self._arrays) != 1:
      raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
                       " arrays.")
    return self._arrays[0].unsafe_buffer_pointer()

  @property
  @use_cpp_method()
  def __cuda_array_interface__(self):
    if len(self._arrays) != 1:
      raise ValueError("__cuda_array_interface__() is supported only for "
                       "unsharded arrays.")
    return self._arrays[0].__cuda_array_interface__  # pytype: disable=attribute-error  # bind-properties

  @use_cpp_method()
  def on_device_size_in_bytes(self):
    """Returns the total global on-device size of the array in bytes."""
    arr = self._arrays[0]
    per_shard_size = arr.on_device_size_in_bytes()
    return per_shard_size * self.sharding.num_devices

|
|
|
|
def devices(self) -> set[Device]:
|
2022-08-18 15:58:40 -07:00
|
|
|
|
self._check_if_deleted()
|
2023-03-09 20:42:45 -08:00
|
|
|
|
return self.sharding.device_set
|
2022-08-17 12:25:14 -07:00
|
|
|
|
|
2022-09-15 13:26:57 -07:00
|
|
|
|
@property
|
2024-04-24 17:26:38 -07:00
|
|
|
|
def device_buffer(self):
|
|
|
|
|
raise AttributeError(
|
|
|
|
|
"arr.device_buffer has been deprecated. Use arr.addressable_data(0)")
|
2022-09-15 13:26:57 -07:00
|
|
|
|
|
|
|
|
|
@property
|
2024-04-24 17:26:38 -07:00
|
|
|
|
def device_buffers(self):
|
|
|
|
|
raise AttributeError(
|
|
|
|
|
"arr.device_buffers has been deprecated. Use [x.data for x in arr.addressable_shards]")
|
2022-09-15 13:26:57 -07:00
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
def addressable_data(self, index: int) -> ArrayImpl:
|
2022-09-21 18:18:57 -07:00
|
|
|
|
self._check_if_deleted()
|
2023-06-01 09:36:32 -07:00
|
|
|
|
if self.is_fully_replicated:
|
2023-04-17 10:05:01 -07:00
|
|
|
|
return self._fully_replicated_shard()
|
2023-03-29 12:58:34 -07:00
|
|
|
|
return self._arrays[index]
|
2022-09-21 18:18:57 -07:00
|
|
|
|
|
2022-11-29 16:39:45 -08:00
|
|
|
|
  @functools.cached_property
  def addressable_shards(self) -> Sequence[Shard]:
    self._check_if_deleted()
    out = []
    for a in self._arrays:
      out.append(Shard(_get_device(a), self.sharding, self.shape, a))
    return out

  @property
  def layout(self):
    # TODO(yashkatariya): Remove the deleted check from here.
    if self.is_deleted():
      return Layout(None, self.sharding)
    try:
      return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout),
                    self.sharding)
    except xe.XlaRuntimeError as e:
      msg, *_ = e.args
      if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
        return Layout(None, self.sharding)
      else:
        raise

  @property
  def global_shards(self) -> Sequence[Shard]:
    """Returns list of all `Shard`s of the Array across all devices.

    The result includes shards that are not addressable by the current process.
    If a `Shard` is not addressable, then its `data` will be `None`.
    """
    self._check_if_deleted()
    if self.is_fully_addressable:  # pylint: disable=using-constant-test
      return self.addressable_shards

    out = []
    device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
    for global_d in self.sharding.device_set:
      if device_id_to_buffer.get(global_d.id, None) is not None:
        array = device_id_to_buffer[global_d.id]
      else:
        array = None
      out.append(Shard(global_d, self.sharding, self.shape, array))
    return out

  @use_cpp_method()
  def delete(self):
    if self._arrays is None:
      return
    for buf in self._arrays:
      buf.delete()
    self._arrays = None
    self._npy_value = None

  @use_cpp_method()
  def is_deleted(self):
    if self._arrays is None:
      return True
    # This path is taken when a view of `Array` is created and the original
    # Array is deleted. In that case, the buffers the view represents also get
    # deleted.
    return any(buf.is_deleted() for buf in self._arrays)

  def _check_if_deleted(self):
    if self.is_deleted():
      raise RuntimeError(
          f"Array has been deleted with shape={self.aval.str_short()}.")

  @use_cpp_method()
  def block_until_ready(self):
    self._check_if_deleted()
    for db in self._arrays:
      db.block_until_ready()
    return self

  @use_cpp_method()
  def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:  # type: ignore
    ...  # pytype: disable=bad-return-type

  @use_cpp_method()
  def _copy_single_device_array_to_host_async(self):
    self._arrays[0].copy_to_host_async()

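  # Illustrative sketch (not executed here): the lifetime helpers defined
  # above. `block_until_ready` waits for async dispatch to finish; `delete`
  # frees the per-device buffers, and later use raises a RuntimeError.
  #
  #   import jax.numpy as jnp
  #
  #   x = jnp.ones((1024, 1024)) @ jnp.ones((1024, 1024))
  #   x.block_until_ready()  # wait for the computation backing x
  #   x.delete()
  #   assert x.is_deleted()
  #   # np.asarray(x) would now raise "Array has been deleted with shape=...".
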
  @profiler.annotate_function
  def copy_to_host_async(self):
    self._check_if_deleted()
    if self._npy_value is None:
      if self.is_fully_replicated:
        self._copy_single_device_array_to_host_async()
        return
      for i, _ in _cached_index_calc(self.sharding, self.shape):
        self._arrays[i]._copy_single_device_array_to_host_async()

  @property
  @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
  def _value(self) -> np.ndarray:
    self._check_if_deleted()

    if self._npy_value is None:
      if self.is_fully_replicated:
        npy_value, did_copy = self._single_device_array_to_np_array_did_copy()
        npy_value.flags.writeable = False
        if did_copy:
          self._npy_value = npy_value
        return npy_value

      # TODO(yashkatariya): Merge `_process_has_full_value_in_mcjax` with
      # is_fully_addressable.
      if (not self.is_fully_addressable and
          not _process_has_full_value_in_mcjax(self.sharding, self.shape)):
        raise RuntimeError(
            "Fetching value for `jax.Array` that spans non-addressable"
            " (non process local) devices is not possible. You can use"
            " `jax.experimental.multihost_utils.process_allgather` to print the"
            " global array or use `.addressable_shards` method of jax.Array to"
            " inspect the addressable (process local) shards."
        )

      for i, _ in _cached_index_calc(self.sharding, self.shape):
        self._arrays[i]._copy_single_device_array_to_host_async()

      npy_value = np.empty(self.shape, self.dtype)
      for i, ind in _cached_index_calc(self.sharding, self.shape):
        npy_value[ind], _ = self._arrays[i]._single_device_array_to_np_array_did_copy()
      self._npy_value = npy_value
      self._npy_value.flags.writeable = False
    return self._npy_value

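
# Illustrative sketch (not executed here): `_value` is what backs
# `np.asarray(jax_array)` and `print(jax_array)`. The host copy is cached on
# the array and marked read-only. Assumes a fully addressable array and no
# dtype conversion in the `np.asarray` call.
#
#   import jax.numpy as jnp
#   import numpy as np
#
#   x = jnp.arange(6).reshape(3, 2)
#   host = np.asarray(x)  # triggers the device-to-host transfer above
#   assert not host.flags.writeable  # the cached host buffer is read-only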

# TODO(b/273265390): ideally we would write this as a decorator on the ArrayImpl
# class, however this triggers a pytype bug. Workaround: apply the decorator
# after the fact.
if not TYPE_CHECKING:
  ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)


def _get_shape_from_index(slc: Index, shape: Shape) -> Shape:
  return tuple(
      (s.stop or dim) - (s.start or 0)
      for s, dim in safe_zip(slc, shape)
      if isinstance(s, slice)  # If element is int, this dimension is reduced
  )


# explicitly set to be unhashable.
setattr(ArrayImpl, "__hash__", None)
setattr(ArrayImpl, "__array_priority__", 100)

# TODO(yashkatariya): Remove None from callback input type.


def make_array_from_callback(
    shape: Shape, sharding: Sharding | Layout,
    data_callback: Callable[[Index | None], ArrayLike]) -> ArrayImpl:
  # pyformat: disable
  """Returns a ``jax.Array`` via data fetched from ``data_callback``.

  ``data_callback`` is used to fetch the data for each addressable shard of the
  returned ``jax.Array``. This function must return concrete arrays, meaning that
  ``make_array_from_callback`` has limited compatibility with JAX transformations
  like :func:`jit` or :func:`vmap`.

  Args:
    shape : Shape of the ``jax.Array``.
    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
      laid out across devices.
    data_callback : Callback that takes indices into the global array value as
      input and returns the corresponding data of the global array value.
      The data can be returned as any array-like object, e.g. a ``numpy.ndarray``.

  Returns:
    A ``jax.Array`` via data fetched from ``data_callback``.

  Examples:

    >>> import math
    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> import numpy as np
    ...
    >>> input_shape = (8, 8)
    >>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
    ...
    >>> def cb(index):
    ...  return global_input_data[index]
    ...
    >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
    >>> arr.addressable_data(0).shape
    (4, 2)
  """
  # pyformat: enable
  dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
  if isinstance(dll, AutoLayout):
    raise TypeError(
        "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
        f" layout when calling `jax.make_array_from_callback`. Got {sharding}")
  sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
  if not isinstance(sharding, Sharding):
    raise TypeError(
        f"sharding should be an instance of `jax.sharding`. Got {sharding} of"
        f" type {type(sharding)}")

  def get_data(index: Index | None) -> ArrayImpl | np.ndarray:
    # Perhaps cache on index here, then we can unify fully_replicated
    # and non-fully_replicated cases below and become faster for
    # partially replicated cases.
    assert index is not None
    r = data_callback(index)
    if isinstance(r, core.Tracer):
      raise errors.UnexpectedTracerError(
          "jax.make_array_from_callback cannot be called within a traced"
          " context."
      )
    # Value can be a python scalar, resolve it into something with a dtype.
    return xla.canonicalize_dtype(r)

  if sharding.is_fully_replicated:
    devices = list(sharding._internal_device_list.addressable_device_list)  # type: ignore
    # Only compute data once.
    per_device_values = [get_data((slice(None),) * len(shape))] * len(devices)
  else:
    device_to_index_map = sharding.addressable_devices_indices_map(shape)
    devices = list(device_to_index_map.keys())
    per_device_values = [
        get_data(device_to_index_map[device]) for device in devices
    ]

  first_value = per_device_values[0]
  expected_dtype = first_value.dtype
  expected_shape = sharding.shard_shape(shape)
  aval = core.update_aval_with_sharding(
      core.ShapedArray(shape, expected_dtype), sharding)
  _validate_shape_and_dtype_for_per_device_arrays(
      per_device_values,
      expected_shape=expected_shape,
      aval=aval,
      sharding=sharding,
  )
  if (isinstance(first_value, ArrayImpl)
      and first_value._committed
      and sharding.is_fully_replicated
      and first_value.is_fully_replicated
      and first_value.sharding._device_assignment == tuple(devices)
      and first_value.layout.device_local_layout == dll):
    return first_value

  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    # TODO(yashkatariya): Can this also use batched_device_put?
    arrays = api.device_put(per_device_values, devices)
    return aval.dtype._rules.make_sharded_array(
        aval, sharding, arrays, committed=True
    )

  if dll is not None:
    devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices]
    # pxla.batched_device_put doesn't support Layout, so take the slow route.
    arrays = api.device_put(per_device_values, devices)
    return ArrayImpl(aval, sharding, arrays, committed=True)

  if isinstance(first_value, ArrayImpl) and len(first_value.devices()) > 1:
    # The output of the callback is already a sharded array; move it to the
    # target devices.
    per_device_values = api.device_put(per_device_values, devices)

  return pxla.batched_device_put(aval, sharding, per_device_values, devices)
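
# Illustrative sketch (not executed here): with a fully replicated sharding,
# `make_array_from_callback` above invokes the callback only once and reuses
# the result for every addressable device (see "Only compute data once").
# Assumes a single-process run with the default backend.
#
#   import jax
#   import numpy as np
#   from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
#
#   mesh = Mesh(np.array(jax.devices()), ("x",))
#   replicated = NamedSharding(mesh, P())  # no partitioned dimensions
#   data = np.arange(16.0).reshape(4, 4)
#   arr = jax.make_array_from_callback(data.shape, replicated,
#                                      lambda idx: data[idx])
#   assert arr.is_fully_replicated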


def make_array_from_process_local_data(
    sharding: Sharding,
    local_data: np.ndarray,
    global_shape: Shape | None = None,
) -> ArrayImpl:
  # pyformat: disable
  """Creates a distributed tensor using the data available in this process.

  This function is a common special case of `make_array_from_callback`. It
  assumes that the data is available in the process and takes care of the
  index wrangling.

  The most common case is when the array is sharded across the batch
  dimension and each host just loads its corresponding sub-batch. This function
  supports more general cases as well, such as mixed multi-host and multi-axis
  replication and sharding, but you would need to compute the size and the
  contents of process-local data correctly to satisfy the sharding constraints.

  In particular, if any two hosts are replicas, host_local_data should be
  identical as well.

  The global_shape is optional. If not provided it will be inferred from
  the local_data and sharding, under the assumption that
  each host represents only their own data for uniform sharding. If sharding
  is non-uniform (see note below), an exception will be raised.

  Setting global_shape explicitly allows for finer-grained control and works
  with non-uniform shardings. Each dimension of global_shape must either match
  host_local_data, or match the inferred global shape of the sharding (in which
  case it is equivalent to setting it to None, but is more explicit).

  For example if dimension `i` is fully sharded then this size would be
  `per_device_shape[i] * jax.local_device_count()`. Each device will be mapped
  into a local slice of the `local_data` array. For example, if a given process
  addresses slices (8, 12) and (24, 28), then these slices will be mapped
  into (0, 4) and (4, 8) of the `local_data`.

  For each dimension where global_shape matches local_shape, each device
  will look up its slice in the local_data. For example if
  global_shape == local_data.shape, the local data is assumed to be the
  actual target array that will be sharded onto the devices.

  If global_shape is the same as local_data.shape, then the data must
  be the same across all hosts.

  Examples:
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh_rows = 2
    >>> mesh_cols = jax.device_count() // 2
    ...
    >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))

    >>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
    >>> rows_per_device = 2
    >>> feature_length = 32
    >>> per_device_shape = (rows_per_device, feature_length)
    >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
    >>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
    >>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
    >>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
    >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape)
    ...
    >>> assert output_global_array.addressable_data(0).shape == per_device_shape
    >>> assert output_global_array.shape == global_shape

  NB: While most shardings are uniform, it is possible to design an exotic
  sharding mesh where each process's devices will be arranged in a non-grid
  like pattern in some dimensions, or for indices to overlap non-trivially.
  Such a sharding is called "non-uniform" in those dimensions. In that case,
  the global shape along those directions must match the local shape, as there
  is no meaningful way to represent all of the needed per-process data in a
  non-overlapping fashion. For example, for global_shape 4x4,
  if the sharding looks like this:

    0123
    2103
    4675
    4567

  with 4 processes containing devices (0,1), (2,3), (4,5), (6,7) respectively,
  then the data for each host looks like

    xx.. ..xx .... ....
    .xx. x..x .... ....
    .... .... x..x .xx.
    .... .... xx.. ..xx

  The sharding is uniform on rows (each host requires either rows 1-2, or rows 3-4)
  and non-uniform on columns (hosts require overlapping but not matching
  sets of columns). Thus local data must have the shape 2x4 or 4x4
  for all hosts, even though each host could potentially fit into a 2x2 shape.
  In this case the user must provide global_shape explicitly, and for
  local_shape=(2, 4), the potentially valid global shapes are (2, 4) and (4, 4).

  On the other hand, for the sharding:

    0213 x.x. .x.x. .... ....
    0213 x.x. .x.x. .... ....
    4657 .... .... .x.x x.x.
    4657 .... .... .x.x x.x.

  for local_shape=(2, 2) this function can accept a choice of 2x2, 2x4, 4x2
|
|
|
|
|
and 4x4 global shapes. Setting global_shape to None, is equivalent to
|
|
|
|
|
setting it to (4, 4) in this case.
|
|
|
|
|
|
2024-05-15 22:06:11 -07:00
|
|
|
|
Args:
|
2024-09-21 10:22:36 -07:00
|
|
|
|
sharding: Sharding of the global array.
|
|
|
|
|
local_data: Data on the host to be placed on local devices. Each
|
2024-05-15 22:06:11 -07:00
|
|
|
|
dimension should either match global_shape, or match
|
|
|
|
|
num_addressable_indices(dim).
|
2024-09-21 10:22:36 -07:00
|
|
|
|
global_shape: The target shape of the global array. If None,
|
2024-08-16 17:21:10 -07:00
|
|
|
|
will infer from local_data and sharding.
|
2024-05-15 22:06:11 -07:00
|
|
|
|
|
|
|
|
|
Returns:
|
2024-06-06 12:40:21 -07:00
|
|
|
|
Tensor that will have sharding=sharding and of shape global_shape.
|
2024-05-15 22:06:11 -07:00
|
|
|
|
"""
|
|
|
|
|
# pyformat: enable
|
2024-09-21 10:22:36 -07:00
|
|
|
|
if xla_bridge.process_count() == 1:
|
|
|
|
|
return api.device_put(local_data, sharding)
|
|
|
|
|
|
2024-06-06 12:40:21 -07:00
|
|
|
|
# TODO(sandler): consider supporting partially specified global_shape or
|
|
|
|
|
# making local_to_global_shape available in the api.
|
|
|
|
|
local_shape = local_data.shape
|
|
|
|
|
if global_shape is None:
|
|
|
|
|
global_shape = local_to_global_shape(sharding, local_shape) # type: ignore[assignment]
|
|
|
|
|
assert global_shape is not None
|
|
|
|
|
if None in global_shape:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unable to compute global_shape due to non-uniform sharding."
|
|
|
|
|
f" Specify global shape directly. Partially computed {global_shape=}."
|
|
|
|
|
)
|
|
|
|
|
elif None in global_shape:
|
|
|
|
|
raise ValueError(f"{global_shape=} has Nones. This is not supported.")
|
2024-05-15 22:06:11 -07:00
|
|
|
|
full_dim = []
|
|
|
|
|
for i, (data_dim, global_dim) in enumerate(
|
|
|
|
|
zip(local_data.shape, global_shape)
|
|
|
|
|
):
|
|
|
|
|
full_dim.append(data_dim == global_dim)
|
|
|
|
|
if data_dim != global_dim:
|
|
|
|
|
process_slice = num_addressable_indices(sharding, i, global_shape)
|
|
|
|
|
if process_slice != data_dim:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Invalid host data, each dimension should match either global or "
|
|
|
|
|
f"process shape. In dimension {i=}, the process data has {data_dim}"
|
|
|
|
|
f"elements. Process addresses {process_slice} elements and "
|
|
|
|
|
f"{global_shape=}."
|
|
|
|
|
)
|
|
|
|
|
addressable_shards = sharding.addressable_devices_indices_map(global_shape)
|
2024-06-27 20:59:25 -07:00
|
|
|
|
shard = next(iter(addressable_shards.values()))
|
|
|
|
|
assert shard is not None
|
|
|
|
|
shard_shape = _get_shape_from_index(shard, global_shape)
|
2024-05-15 22:06:11 -07:00
|
|
|
|
slices_for_each_dim: list[list[int]] = [[] for _ in global_shape]
|
|
|
|
|
for shard_index in addressable_shards.values():
|
|
|
|
|
assert shard_index is not None
|
|
|
|
|
for i, slc in enumerate(shard_index):
|
|
|
|
|
slices_for_each_dim[i].append(slc.start or 0)
|
|
|
|
|
for i in range(len(global_shape)):
|
|
|
|
|
slices_for_each_dim[i] = sorted(set(slices_for_each_dim[i]))
|
|
|
|
|
|
2024-06-27 20:59:25 -07:00
|
|
|
|
@functools.lru_cache(maxsize=4096)
|
|
|
|
|
def local_slice(i, start):
|
2024-05-15 22:06:11 -07:00
|
|
|
|
# Looks up the index of this slice in the list of slices for this dimension.
|
|
|
|
|
# This will determine the slice in host_local_data
|
2024-06-27 20:59:25 -07:00
|
|
|
|
start = slices_for_each_dim[i].index(start or 0) * shard_shape[i]
|
2024-05-15 22:06:11 -07:00
|
|
|
|
end = start + shard_shape[i]
|
|
|
|
|
return slice(start, end)
|
|
|
|
|
|
|
|
|
|
def cb(index: Index | None) -> ArrayLike:
|
|
|
|
|
assert index is not None
|
2024-06-27 20:59:25 -07:00
|
|
|
|
data_slice = (
|
|
|
|
|
slc if full_dim[i] else local_slice(i, slc.start)
|
2024-05-15 22:06:11 -07:00
|
|
|
|
for i, slc in enumerate(index)
|
2024-06-27 20:59:25 -07:00
|
|
|
|
)
|
2024-05-15 22:06:11 -07:00
|
|
|
|
return local_data[tuple(data_slice)]
|
|
|
|
|
|
|
|
|
|
return make_array_from_callback(global_shape, sharding, cb)
|
|
|
|
|
|
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
def make_array_from_single_device_arrays(
|
2025-02-26 16:56:47 -08:00
|
|
|
|
shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
|
2023-03-14 14:19:25 -07:00
|
|
|
|
) -> ArrayImpl:
|
2023-10-15 21:55:10 +00:00
|
|
|
|
r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s each on a single device.
|
|
|
|
|
Every device in input ``sharding``\'s mesh must have an array in ``arrays``\s.
|
2022-11-11 15:20:27 -08:00
|
|
|
|
|
|
|
|
|
Args:
|
2023-10-15 21:55:10 +00:00
|
|
|
|
shape : Shape of the output ``jax.Array``. This conveys information already included with
|
|
|
|
|
``sharding`` and ``arrays`` and serves as a double check.
|
|
|
|
|
sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices.
|
|
|
|
|
arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)``
|
|
|
|
|
must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code,
|
|
|
|
|
each process will call with a different ``arrays`` argument that corresponds to that processes' data.
|
|
|
|
|
These arrays are commonly created via ``jax.device_put``.
|
2022-11-11 15:20:27 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
2023-10-15 21:55:10 +00:00
|
|
|
|
A global ``jax.Array``, sharded as ``sharding``, with shape equal to ``shape``, and with per-device
|
|
|
|
|
contents matching ``arrays``.
|
2022-11-11 15:20:27 -08:00
|
|
|
|
|
2023-10-15 21:55:10 +00:00
|
|
|
|
Examples:
|
|
|
|
|
|
2023-02-28 12:40:30 -08:00
|
|
|
|
>>> import math
|
2023-02-09 05:47:59 -08:00
|
|
|
|
>>> from jax.sharding import Mesh
|
|
|
|
|
>>> from jax.sharding import PartitionSpec as P
|
2022-11-11 15:20:27 -08:00
|
|
|
|
>>> import numpy as np
|
|
|
|
|
...
|
2023-10-15 21:55:10 +00:00
|
|
|
|
>>> mesh_rows = 2
|
|
|
|
|
>>> mesh_cols = jax.device_count() // 2
|
|
|
|
|
...
|
2023-05-04 19:11:26 -07:00
|
|
|
|
>>> global_shape = (8, 8)
|
2023-10-15 21:55:10 +00:00
|
|
|
|
>>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
|
|
|
|
|
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
2023-05-04 19:11:26 -07:00
|
|
|
|
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
|
2022-11-11 15:20:27 -08:00
|
|
|
|
...
|
|
|
|
|
>>> arrays = [
|
2023-10-15 21:55:10 +00:00
|
|
|
|
... jax.device_put(inp_data[index], d)
|
|
|
|
|
... for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
|
2022-11-11 15:20:27 -08:00
|
|
|
|
...
|
2023-05-04 19:11:26 -07:00
|
|
|
|
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
|
2023-10-15 21:55:10 +00:00
|
|
|
|
>>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()
|
2023-05-04 19:11:26 -07:00
|
|
|
|
|
2024-07-12 18:09:27 -07:00
|
|
|
|
For cases where you have a local array and want to convert it to a global
|
|
|
|
|
jax.Array, use ``jax.make_array_from_process_local_data``.
|
2022-11-11 15:20:27 -08:00
|
|
|
|
"""
|
2022-09-26 12:43:13 -07:00
|
|
|
|
# All input arrays should be committed. Checking it is expensive on
|
|
|
|
|
# single-controller systems.
|
2025-01-20 15:12:12 -08:00
|
|
|
|
aval = core.update_aval_with_sharding(
|
2025-02-26 16:56:47 -08:00
|
|
|
|
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
|
2023-07-24 14:29:37 -07:00
|
|
|
|
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
2024-04-15 12:37:46 -07:00
|
|
|
|
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
|
|
|
|
|
committed=True)
|
2023-08-18 16:50:36 -04:00
|
|
|
|
# TODO(phawkins): ideally the cast() could be checked.
|
2025-01-25 07:11:18 +00:00
|
|
|
|
try:
|
|
|
|
|
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
|
|
|
|
|
committed=True)
|
|
|
|
|
except TypeError:
|
|
|
|
|
if not isinstance(arrays, Sequence):
|
|
|
|
|
raise TypeError("jax.make_array_from_single_device_arrays `arrays` "
|
|
|
|
|
"argument must be a Sequence (list or tuple), but got "
|
|
|
|
|
f"{type(arrays)}.")
|
|
|
|
|
if any(isinstance(arr, core.Tracer) for arr in arrays):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"jax.make_array_from_single_device_arrays requires a list of concrete"
|
|
|
|
|
f" arrays as input, but got types {set(map(type, arrays))}")
|
|
|
|
|
raise
|
2022-09-26 12:43:13 -07:00
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
|
2024-12-17 13:47:58 -08:00
|
|
|
|
|
2024-08-29 10:49:30 -07:00
|
|
|
|
def _get_aval_array(self):
|
2025-01-20 15:12:12 -08:00
|
|
|
|
return core.update_aval_with_sharding(self.aval, self.sharding)
|
2024-12-17 13:47:58 -08:00
|
|
|
|
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
|
|
|
|
|
|
2022-12-19 13:13:15 -08:00
|
|
|
|
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
|
|
|
|
basearray.Array.register(ArrayImpl)
|
2022-08-12 12:09:22 -07:00
|
|
|
|
|
|
|
|
|
|
2023-08-17 06:43:31 -07:00
|
|
|
|
def _array_mlir_constant_handler(val):
|
2024-06-29 14:36:03 -07:00
|
|
|
|
try:
|
2024-07-01 08:42:48 -04:00
|
|
|
|
return mlir.ir_constant(val._value)
|
2024-06-29 14:36:03 -07:00
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
# TODO(yashkatariya): Ideally we would catch a custom exception from
|
|
|
|
|
# `_value` function in ArrayImpl instead of checking the error string.
|
|
|
|
|
if 'Fetching value for `jax.Array` that spans non-addressable' in str(e):
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Closing over jax.Array that spans non-addressable (non process"
|
|
|
|
|
" local) devices is not allowed. Please pass such arrays as arguments"
|
|
|
|
|
f" to the function. Got jax.Array: {val.aval.str_short()}") from e
|
|
|
|
|
raise
|
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
2022-06-24 10:04:31 -07:00
|
|
|
|
|
2023-07-20 09:43:40 -07:00
|
|
|
|
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
|
|
|
|
# from the input ShardingSpec, rather than the indices. However, this would
|
|
|
|
|
# require duplicating the ordering logic of spec_to_indices, which is more
|
|
|
|
|
# subtle and more likely to change than the index logic we have to support here.
|
|
|
|
|
def as_slice_indices(arr: Any, idx: Index) -> tuple[
|
|
|
|
|
tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
|
|
|
|
"""Returns start_indices, limit_indices, removed_dims"""
|
|
|
|
|
start_indices = [0] * arr.ndim
|
|
|
|
|
limit_indices = list(arr.shape)
|
2023-10-11 12:54:51 -07:00
|
|
|
|
removed_dims: list[int] = []
|
2023-07-20 09:43:40 -07:00
|
|
|
|
|
|
|
|
|
tuple_idx = idx if isinstance(idx, tuple) else (idx,)
|
|
|
|
|
for dim, sub_idx in enumerate(tuple_idx):
|
|
|
|
|
if isinstance(sub_idx, int):
|
|
|
|
|
start_indices[dim] = sub_idx
|
|
|
|
|
limit_indices[dim] = sub_idx + 1
|
|
|
|
|
removed_dims.append(dim)
|
|
|
|
|
elif sub_idx == slice(None):
|
|
|
|
|
continue
|
|
|
|
|
else:
|
|
|
|
|
assert isinstance(sub_idx, slice), sub_idx
|
|
|
|
|
assert isinstance(sub_idx.start, int), sub_idx
|
|
|
|
|
assert isinstance(sub_idx.stop, int), sub_idx
|
|
|
|
|
start_indices[dim] = sub_idx.start
|
|
|
|
|
limit_indices[dim] = sub_idx.stop
|
|
|
|
|
|
2024-05-17 09:46:36 +01:00
|
|
|
|
return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)
|
2023-07-20 09:43:40 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shard_device_array(x, devices, indices, sharding):
|
|
|
|
|
start_indices, limit_indices, removed_dims = unzip3(
|
|
|
|
|
as_slice_indices(x, idx) for idx in indices)
|
2023-08-11 10:29:41 -07:00
|
|
|
|
if sharding.is_fully_replicated:
|
|
|
|
|
shards = [x] * len(devices)
|
|
|
|
|
else:
|
2025-02-18 15:22:06 -08:00
|
|
|
|
# TODO(yashkatariya): Maybe this should be set when we call the handler in
|
|
|
|
|
# InputsHandler.__call__?
|
|
|
|
|
with set_concrete_mesh(None):
|
|
|
|
|
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
|
2024-12-19 07:06:12 -08:00
|
|
|
|
aval = core.shaped_abstractify(x)
|
2023-08-10 15:25:39 -07:00
|
|
|
|
return pxla.batched_device_put(aval, sharding, shards, devices)
|
2023-07-20 09:43:40 -07:00
|
|
|
|
|
2024-04-12 21:40:47 -07:00
|
|
|
|
|
2023-07-20 09:43:40 -07:00
|
|
|
|
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
|
|
|
|
|
candidates = defaultdict(list)
|
2024-04-12 21:40:47 -07:00
|
|
|
|
bufs = [buf.data for buf in x.addressable_shards]
|
|
|
|
|
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
|
2023-07-20 09:43:40 -07:00
|
|
|
|
for buf, idx in safe_zip(bufs, arr_indices):
|
2024-07-20 09:08:16 -07:00
|
|
|
|
candidates[hashed_index(idx)].append(buf)
|
2023-07-20 09:43:40 -07:00
|
|
|
|
|
|
|
|
|
bufs = []
|
|
|
|
|
for idx, device in safe_zip(indices, devices):
|
|
|
|
|
# Look up all buffers that contain the correct slice of the logical array.
|
2024-07-20 09:08:16 -07:00
|
|
|
|
candidates_list = candidates[hashed_index(idx)]
|
2023-07-20 09:43:40 -07:00
|
|
|
|
if not candidates_list:
|
2024-11-07 15:50:32 -08:00
|
|
|
|
return pxla.shard_args([sharding], [None], [None], [x._value],
|
2024-08-19 15:10:00 -07:00
|
|
|
|
canonicalize=False)[0]
|
2023-07-20 09:43:40 -07:00
|
|
|
|
# Try to find a candidate buffer already on the correct device,
|
|
|
|
|
# otherwise copy one of them.
|
|
|
|
|
for buf in candidates_list:
|
2023-11-29 16:52:09 -08:00
|
|
|
|
if buf.devices() == {device}:
|
2023-07-20 09:43:40 -07:00
|
|
|
|
bufs.append(buf)
|
|
|
|
|
break
|
|
|
|
|
else:
|
2024-12-09 06:52:25 -08:00
|
|
|
|
bufs.append(candidates_list[-1])
|
2023-07-20 09:43:40 -07:00
|
|
|
|
return pxla.batched_device_put(x.aval, sharding, bufs, devices)
|
|
|
|
|
|
|
|
|
|
|
2024-06-11 12:46:11 -07:00
|
|
|
|
@cache(max_size=4096, trace_context_in_key=False)
|
2024-04-22 10:23:47 -07:00
|
|
|
|
def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
|
|
|
|
|
src_indices = src_sharding.addressable_devices_indices_map(shape).values()
|
|
|
|
|
dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
|
|
|
|
|
return dst_indices, tuple(src_indices) == tuple(dst_indices)
|
|
|
|
|
|
|
|
|
|
|
2024-11-07 15:50:32 -08:00
|
|
|
|
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
|
2024-12-11 16:54:52 -05:00
|
|
|
|
util.test_event("_array_shard_arg")
|
2024-06-13 13:09:35 -07:00
|
|
|
|
results = []
|
|
|
|
|
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
|
2024-11-07 15:50:32 -08:00
|
|
|
|
batch_cs = []
|
2024-08-19 15:10:00 -07:00
|
|
|
|
|
2024-11-07 15:50:32 -08:00
|
|
|
|
for i, (x, sharding, layout, cs) in enumerate(
|
|
|
|
|
safe_zip(xs, shardings, layouts, copy_semantics)):
|
2024-06-13 13:09:35 -07:00
|
|
|
|
x._check_if_deleted()
|
2024-08-19 15:10:00 -07:00
|
|
|
|
indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
|
2024-08-28 11:05:45 -07:00
|
|
|
|
same_layout = (True if layout is None else
|
|
|
|
|
x.layout.device_local_layout == layout)
|
2022-08-19 21:36:43 -07:00
|
|
|
|
|
2024-06-13 13:09:35 -07:00
|
|
|
|
if not x.is_fully_addressable:
|
2024-08-19 15:10:00 -07:00
|
|
|
|
if same_indices and same_layout:
|
2024-06-13 13:09:35 -07:00
|
|
|
|
results.append(x)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Cannot reshard an input that is not fully addressable")
|
2022-10-08 11:39:05 -07:00
|
|
|
|
else:
|
2024-06-13 13:09:35 -07:00
|
|
|
|
devices = sharding._addressable_device_assignment
|
2024-08-19 15:10:00 -07:00
|
|
|
|
if same_indices and same_layout:
|
2024-06-13 13:09:35 -07:00
|
|
|
|
# Add a placeholder result that will be filled in later.
|
|
|
|
|
results.append(None)
|
|
|
|
|
# Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
|
|
|
|
|
batch_xs.append(x)
|
|
|
|
|
batch_devs.append(list(devices))
|
|
|
|
|
batch_shardings.append(sharding)
|
|
|
|
|
batch_indices.append(i)
|
2024-11-07 15:50:32 -08:00
|
|
|
|
batch_cs.append(cs)
|
2024-06-13 13:09:35 -07:00
|
|
|
|
# Resharding starts here:
|
2024-08-19 15:10:00 -07:00
|
|
|
|
elif not same_layout:
|
|
|
|
|
results.append(api.device_put(x, Layout(layout, sharding)))
|
2024-06-13 13:09:35 -07:00
|
|
|
|
elif dispatch.is_single_device_sharding(x.sharding):
|
|
|
|
|
results.append(shard_device_array(x, devices, indices, sharding))
|
|
|
|
|
else:
|
|
|
|
|
results.append(
|
|
|
|
|
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
|
|
|
|
|
|
2024-12-11 16:54:52 -05:00
|
|
|
|
util.test_event("batched_copy_array")
|
2024-12-09 07:34:26 -08:00
|
|
|
|
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
|
|
|
|
batch_xs, batch_devs, batch_shardings, batch_cs)
|
2024-06-13 13:09:35 -07:00
|
|
|
|
for i, copy_out in safe_zip(batch_indices, copy_outs):
|
|
|
|
|
assert results[i] is None
|
|
|
|
|
results[i] = copy_out
|
|
|
|
|
return results
|
2022-09-26 16:17:26 -07:00
|
|
|
|
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
|
|
|
|
|
2024-02-28 15:21:50 -08:00
|
|
|
|
def _array_global_result_handler(global_aval, out_sharding, committed):
|
2025-01-20 15:12:12 -08:00
|
|
|
|
global_aval = core.update_aval_with_sharding(global_aval, out_sharding)
|
2022-08-29 22:02:32 -07:00
|
|
|
|
if global_aval.dtype == dtypes.float0:
|
2024-05-17 09:46:36 +01:00
|
|
|
|
return lambda _: np.zeros(global_aval.shape, dtypes.float0)
|
2023-07-24 14:29:37 -07:00
|
|
|
|
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
|
2022-08-30 13:25:49 -07:00
|
|
|
|
return global_aval.dtype._rules.global_sharded_result_handler(
|
2024-02-28 15:21:50 -08:00
|
|
|
|
global_aval, out_sharding, committed)
|
2023-03-13 17:09:06 -07:00
|
|
|
|
return xc.array_result_handler(
|
|
|
|
|
global_aval, out_sharding, committed=committed, _skip_checks=True
|
|
|
|
|
)
|
2023-03-20 09:09:15 -07:00
|
|
|
|
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
|
2024-04-18 11:09:02 -07:00
|
|
|
|
|
2022-08-31 15:06:58 -07:00
|
|
|
|
# Only used for Arrays that come out of pmap.
|
2022-08-10 20:11:06 -07:00
|
|
|
|
def _array_local_result_handler(aval, sharding, indices):
|
2023-02-17 11:52:08 -08:00
|
|
|
|
if aval.dtype == dtypes.float0:
|
2024-05-17 09:46:36 +01:00
|
|
|
|
return lambda _: np.zeros(aval.shape, dtypes.float0)
|
2023-07-24 14:29:37 -07:00
|
|
|
|
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
2022-08-30 13:25:49 -07:00
|
|
|
|
return aval.dtype._rules.local_sharded_result_handler(
|
|
|
|
|
aval, sharding, indices)
|
2023-03-13 17:09:06 -07:00
|
|
|
|
return xc.array_result_handler(
|
|
|
|
|
aval, sharding, committed=True, _skip_checks=True
|
|
|
|
|
)
|
2023-03-20 09:09:15 -07:00
|
|
|
|
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
|
2024-05-10 10:11:55 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Token handlers
|
|
|
|
|
|
2024-11-07 15:50:32 -08:00
|
|
|
|
def _token_shard_arg(xs, shardings, layouts, copy_semantics):
|
2025-01-16 11:23:39 -08:00
|
|
|
|
results = []
|
|
|
|
|
for x, sharding, layout in safe_zip(xs, shardings, layouts):
|
|
|
|
|
x.block_until_ready()
|
|
|
|
|
x = np.array([], dtype=bool)
|
|
|
|
|
results.append(api.device_put(x, Layout(layout, sharding)))
|
|
|
|
|
return results
|
2024-05-10 10:11:55 -07:00
|
|
|
|
pxla.shard_arg_handlers[core.Token] = _token_shard_arg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _token_global_result_handler(global_aval, out_sharding, committed):
|
|
|
|
|
array_handler = _array_global_result_handler(
|
2025-02-03 17:59:44 -08:00
|
|
|
|
core.get_token_aval(), out_sharding, committed)
|
2024-05-10 10:11:55 -07:00
|
|
|
|
|
|
|
|
|
def wrapper(*args, **kwargs):
|
|
|
|
|
out_buf = array_handler(*args, **kwargs)
|
|
|
|
|
return core.Token(out_buf)
|
|
|
|
|
return wrapper
|
|
|
|
|
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
|