2022-09-22 12:26:48 -07:00
|
|
|
|
# Copyright 2021 The JAX Authors.
|
2022-06-06 17:31:20 -07:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
import operator as op
|
2022-06-06 17:31:20 -07:00
|
|
|
|
import numpy as np
|
2022-09-19 16:58:46 -07:00
|
|
|
|
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-06-10 07:31:43 -07:00
|
|
|
|
from jax import core
|
2022-08-18 12:31:30 -07:00
|
|
|
|
from jax._src import abstract_arrays
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax._src import ad_util
|
2022-07-27 10:54:54 -07:00
|
|
|
|
from jax._src import api_util
|
2022-09-23 09:59:46 -07:00
|
|
|
|
from jax._src import basearray
|
2022-06-24 10:04:31 -07:00
|
|
|
|
from jax._src import dispatch
|
2022-08-16 16:51:26 -07:00
|
|
|
|
from jax._src import dtypes
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax._src.lax import lax as lax_internal
|
2022-06-06 17:31:20 -07:00
|
|
|
|
from jax._src.config import config
|
2022-08-10 20:11:06 -07:00
|
|
|
|
from jax._src.util import prod, safe_zip
|
2022-06-06 17:31:20 -07:00
|
|
|
|
from jax._src.lib import xla_client as xc
|
|
|
|
|
from jax._src.api import device_put
|
2022-09-23 09:59:46 -07:00
|
|
|
|
from jax._src.typing import ArrayLike
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax.interpreters import pxla, xla, mlir
|
2022-09-27 10:06:10 -07:00
|
|
|
|
from jax._src.sharding import (
|
2022-08-30 10:45:29 -07:00
|
|
|
|
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
|
|
|
|
|
device_replica_id_map)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
|
|
|
|
Shape = Tuple[int, ...]
|
|
|
|
|
Device = xc.Device
|
|
|
|
|
DeviceArray = xc.Buffer
|
|
|
|
|
Index = Tuple[slice, ...]
|
|
|
|
|
|
|
|
|
|
|
2022-06-06 18:44:45 -07:00
|
|
|
|
class Shard:
|
|
|
|
|
"""A single data shard of an Array.
|
|
|
|
|
|
2022-06-14 10:34:19 -07:00
|
|
|
|
Attributes:
|
2022-06-06 18:44:45 -07:00
|
|
|
|
device : Which device this shard resides on.
|
|
|
|
|
index : The index into the global array of this shard.
|
|
|
|
|
replica_id : Integer id indicating which replica of the global array this
|
|
|
|
|
shard is part of. Always 0 for fully sharded data
|
|
|
|
|
(i.e. when there’s only 1 replica).
|
|
|
|
|
data : The data of this shard. None if ``device`` is non-local.
|
|
|
|
|
"""
|
2022-06-14 10:34:19 -07:00
|
|
|
|
|
|
|
|
|
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
|
2022-09-26 16:17:26 -07:00
|
|
|
|
data: Optional[ArrayImpl] = None):
|
2022-06-14 10:34:19 -07:00
|
|
|
|
self.device = device
|
|
|
|
|
self._sharding = sharding
|
|
|
|
|
self._global_shape = global_shape
|
|
|
|
|
self.data = data
|
|
|
|
|
|
2022-06-22 02:25:34 -07:00
|
|
|
|
def __repr__(self):
|
|
|
|
|
try:
|
|
|
|
|
return (f'Shard(device={repr(self.device)}, index={self.index}, '
|
|
|
|
|
f'replica_id={self.replica_id}, data={self.data})')
|
|
|
|
|
except ValueError:
|
|
|
|
|
return f'Shard(device={repr(self.device)}, data={self.data})'
|
|
|
|
|
|
2022-06-14 10:34:19 -07:00
|
|
|
|
@property
|
|
|
|
|
def index(self) -> Index:
|
2022-09-18 15:35:18 -07:00
|
|
|
|
try:
|
|
|
|
|
device_indices_map_fn = self._sharding.devices_indices_map
|
|
|
|
|
except AttributeError:
|
|
|
|
|
raise ValueError('Cannot calculate indices from sharding: '
|
|
|
|
|
f'{self._sharding}. Please create a device to index '
|
|
|
|
|
'mapping for your sharding.') from None
|
|
|
|
|
index = device_indices_map_fn(self._global_shape)[self.device]
|
2022-06-14 10:34:19 -07:00
|
|
|
|
assert index is not None
|
|
|
|
|
return index
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def replica_id(self) -> int:
|
2022-08-29 14:49:17 -07:00
|
|
|
|
return device_replica_id_map(self._sharding, self._global_shape)[self.device]
|
2022-06-06 18:44:45 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-18 15:58:40 -07:00
|
|
|
|
def _reconstruct_array(fun, args, arr_state, aval_state):
|
|
|
|
|
"""Method to reconstruct a device array from a serialized state."""
|
|
|
|
|
np_value = fun(*args)
|
|
|
|
|
np_value.__setstate__(arr_state)
|
|
|
|
|
jnp_value = device_put(np_value)
|
|
|
|
|
jnp_value.aval = jnp_value.aval.update(**aval_state)
|
|
|
|
|
return jnp_value
|
|
|
|
|
|
2022-09-13 16:18:31 -07:00
|
|
|
|
|
2022-09-23 13:29:47 -07:00
|
|
|
|
def _single_device_array_from_buf(buf, committed):
|
|
|
|
|
db = pxla._set_aval(buf)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return ArrayImpl(db.aval, SingleDeviceSharding(db.device()), [db],
|
|
|
|
|
committed=committed, _skip_checks=True)
|
2022-09-23 13:29:47 -07:00
|
|
|
|
|
|
|
|
|
|
2022-09-30 09:55:25 -07:00
|
|
|
|
@pxla.use_cpp_class(xc.ArrayImpl if xc._version >= 97 else None)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
class ArrayImpl(basearray.Array):
|
2022-06-06 17:31:20 -07:00
|
|
|
|
# TODO(yashkatariya): Add __slots__ here.
|
|
|
|
|
|
2022-09-21 12:51:32 -07:00
|
|
|
|
aval: core.ShapedArray
|
|
|
|
|
_sharding: Sharding
|
|
|
|
|
_arrays: List[DeviceArray]
|
|
|
|
|
_committed: bool
|
|
|
|
|
_skip_checks: bool
|
|
|
|
|
_npy_value: Optional[np.ndarray]
|
|
|
|
|
|
2022-09-13 16:18:31 -07:00
|
|
|
|
@pxla.use_cpp_method
|
2022-08-17 12:25:14 -07:00
|
|
|
|
def __init__(self, aval: core.ShapedArray, sharding: Sharding,
|
2022-09-26 16:17:26 -07:00
|
|
|
|
arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
|
2022-09-18 15:35:18 -07:00
|
|
|
|
committed: bool, _skip_checks: bool = False):
|
2022-09-08 13:47:57 -07:00
|
|
|
|
# NOTE: the actual implementation of the constructor is moved to C++.
|
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
self.aval = aval
|
2022-06-06 17:31:20 -07:00
|
|
|
|
self._sharding = sharding
|
2022-06-22 09:20:26 -07:00
|
|
|
|
# Extract DeviceArrays from arrays with `SingleDeviceSharding` to keep the
|
|
|
|
|
# code handling `self._arrays` simpler.
|
|
|
|
|
# TODO(yashkatariya): This will be slower as it will happen during
|
|
|
|
|
# `__init__` on single controller environment. Make it lazy.
|
2022-09-21 12:51:32 -07:00
|
|
|
|
self._arrays = [a if isinstance(a, DeviceArray) else a._arrays[0] for a in arrays]
|
2022-06-06 17:31:20 -07:00
|
|
|
|
# See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
|
|
|
|
|
# for what committed means.
|
|
|
|
|
self._committed = committed
|
2022-06-13 18:07:55 -07:00
|
|
|
|
self._npy_value = None
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-08-23 10:19:59 -07:00
|
|
|
|
# Don't rearrange if skip_checks is enabled because this assumes that the
|
|
|
|
|
# input buffers are already arranged properly. This usually happens when
|
|
|
|
|
# Array's are created as output of a JAX transformation
|
|
|
|
|
# (like pjit, xmap, etc).
|
2022-09-30 09:55:25 -07:00
|
|
|
|
if not _skip_checks or config.jax_enable_checks:
|
|
|
|
|
self._check_and_rearrange()
|
|
|
|
|
|
|
|
|
|
def _check_and_rearrange(self):
|
|
|
|
|
for db in self._arrays:
|
|
|
|
|
if db.dtype != self.dtype:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input buffers to `Array` must have matching dtypes. "
|
|
|
|
|
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")
|
|
|
|
|
|
|
|
|
|
device_id_to_buffer = {db.device().id: db for db in self._arrays}
|
|
|
|
|
|
|
|
|
|
addressable_dev = self.sharding.addressable_devices
|
|
|
|
|
if len(self._arrays) != len(addressable_dev):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Expected {len(addressable_dev)} per-device arrays "
|
|
|
|
|
"(this is how many devices are addressable by the sharding), but "
|
|
|
|
|
f"got {len(self._arrays)}")
|
|
|
|
|
|
|
|
|
|
array_device_ids = set(device_id_to_buffer.keys())
|
|
|
|
|
addressable_device_ids = set(d.id for d in addressable_dev)
|
|
|
|
|
# Calculate a symmetric difference because the device ids between sharding
|
|
|
|
|
# and _arrays should match.
|
|
|
|
|
diff = set(array_device_ids) ^ set(addressable_device_ids)
|
|
|
|
|
if diff:
|
|
|
|
|
dev_in_sharding_not_in_arrays = set(addressable_device_ids) - set(array_device_ids)
|
|
|
|
|
dev_in_arrays_not_in_sharding = set(array_device_ids) - set(addressable_device_ids)
|
|
|
|
|
err_msg = (
|
|
|
|
|
"Addressable devices and per-device arrays devices do not match.")
|
|
|
|
|
if dev_in_sharding_not_in_arrays:
|
|
|
|
|
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
|
|
|
|
|
"that are not present in per-device arrays.")
|
|
|
|
|
if dev_in_arrays_not_in_sharding:
|
|
|
|
|
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
|
|
|
|
|
"that are not present in the sharding.")
|
|
|
|
|
raise ValueError(err_msg)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-09-08 13:47:57 -07:00
|
|
|
|
ss = self.sharding.shard_shape(self.shape)
|
|
|
|
|
for db in self._arrays:
|
|
|
|
|
if db.shape != ss:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Expected shard shape {ss} doesn't match the buffer "
|
|
|
|
|
f"shape {db.shape} for buffer: {db}")
|
|
|
|
|
|
|
|
|
|
# Rearrange arrays based on the device assignment.
|
|
|
|
|
if isinstance(self.sharding, XLACompatibleSharding):
|
2022-09-18 15:35:18 -07:00
|
|
|
|
addressable_da = self.sharding._addressable_device_assignment
|
2022-09-30 09:55:25 -07:00
|
|
|
|
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
|
2022-09-08 13:47:57 -07:00
|
|
|
|
|
2022-06-06 17:31:20 -07:00
|
|
|
|
@property
|
|
|
|
|
def shape(self) -> Shape:
|
2022-08-17 12:25:14 -07:00
|
|
|
|
return self.aval.shape
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-06-24 10:04:31 -07:00
|
|
|
|
@property
|
2022-08-17 12:25:14 -07:00
|
|
|
|
def dtype(self):
|
|
|
|
|
return self.aval.dtype
|
2022-06-24 10:04:31 -07:00
|
|
|
|
|
2022-06-06 17:31:20 -07:00
|
|
|
|
@property
|
|
|
|
|
def ndim(self):
|
|
|
|
|
return len(self.shape)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def size(self):
|
|
|
|
|
return prod(self.shape)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def sharding(self):
|
|
|
|
|
return self._sharding
|
|
|
|
|
|
2022-08-16 16:51:26 -07:00
|
|
|
|
def __str__(self):
|
|
|
|
|
return str(self._value)
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
try:
|
|
|
|
|
return self.shape[0]
|
|
|
|
|
except IndexError as err:
|
|
|
|
|
raise TypeError("len() of unsized object") from err # same as numpy error
|
|
|
|
|
|
|
|
|
|
def __bool__(self):
|
|
|
|
|
return bool(self._value)
|
|
|
|
|
|
|
|
|
|
def __nonzero__(self):
|
|
|
|
|
return bool(self._value)
|
|
|
|
|
|
|
|
|
|
def __float__(self):
|
|
|
|
|
return self._value.__float__()
|
|
|
|
|
|
|
|
|
|
def __int__(self):
|
|
|
|
|
return self._value.__int__()
|
|
|
|
|
|
|
|
|
|
def __complex__(self):
|
|
|
|
|
return self._value.__complex__()
|
|
|
|
|
|
|
|
|
|
def __hex__(self):
|
|
|
|
|
assert self.ndim == 0, 'hex only works on scalar values'
|
|
|
|
|
return hex(self._value) # type: ignore
|
|
|
|
|
|
|
|
|
|
def __oct__(self):
|
|
|
|
|
assert self.ndim == 0, 'oct only works on scalar values'
|
|
|
|
|
return oct(self._value) # type: ignore
|
|
|
|
|
|
|
|
|
|
def __index__(self):
|
2022-08-17 12:25:14 -07:00
|
|
|
|
return op.index(self._value)
|
2022-08-16 16:51:26 -07:00
|
|
|
|
|
2022-08-18 15:58:40 -07:00
|
|
|
|
def tobytes(self, order="C"):
|
2022-08-16 16:51:26 -07:00
|
|
|
|
return self._value.tobytes(order)
|
|
|
|
|
|
|
|
|
|
def tolist(self):
|
|
|
|
|
return self._value.tolist()
|
|
|
|
|
|
|
|
|
|
def __format__(self, format_spec):
|
|
|
|
|
# Simulates behavior of https://github.com/numpy/numpy/pull/9883
|
|
|
|
|
if self.ndim == 0:
|
|
|
|
|
return format(self._value[()], format_spec)
|
|
|
|
|
else:
|
|
|
|
|
return format(self._value, format_spec)
|
|
|
|
|
|
2022-09-09 14:24:39 -07:00
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
|
from jax._src.numpy import lax_numpy
|
|
|
|
|
self._check_if_deleted()
|
|
|
|
|
|
2022-09-09 20:41:12 -07:00
|
|
|
|
if dispatch.is_single_device_sharding(self.sharding):
|
|
|
|
|
return lax_numpy._rewriting_take(self, idx)
|
2022-09-09 14:24:39 -07:00
|
|
|
|
# TODO(yashkatariya): Make it work for other Shardings too wherever its
|
|
|
|
|
# possible to not do data movement.
|
2022-09-09 20:41:12 -07:00
|
|
|
|
elif isinstance(self.sharding, PmapSharding):
|
|
|
|
|
if not isinstance(idx, tuple):
|
|
|
|
|
cidx = (idx,) + (slice(None),) * (len(self.shape) - 1)
|
2022-09-09 14:24:39 -07:00
|
|
|
|
else:
|
2022-09-09 20:41:12 -07:00
|
|
|
|
cidx = idx + (slice(None),) * (len(self.shape) - len(idx))
|
|
|
|
|
if self._npy_value is None:
|
2022-09-18 15:35:18 -07:00
|
|
|
|
indices = tuple(self.sharding.devices_indices_map(self.shape).values())
|
2022-09-09 20:41:12 -07:00
|
|
|
|
try:
|
|
|
|
|
buf_idx = indices.index(cidx)
|
|
|
|
|
except ValueError:
|
|
|
|
|
buf_idx = None
|
|
|
|
|
if buf_idx is not None:
|
|
|
|
|
buf = self._arrays[buf_idx]
|
2022-09-23 13:29:47 -07:00
|
|
|
|
aval = core.ShapedArray(buf.shape, self.dtype)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
|
|
|
|
|
committed=False, _skip_checks=True)
|
2022-09-09 20:41:12 -07:00
|
|
|
|
return lax_numpy._rewriting_take(self, idx)
|
|
|
|
|
else:
|
|
|
|
|
# TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or
|
2022-10-03 22:28:26 -07:00
|
|
|
|
# the fast path (see PmapSharding branch above) after after uneven
|
|
|
|
|
# partitioning support is added
|
|
|
|
|
return device_put(self._value[idx])
|
2022-09-09 14:24:39 -07:00
|
|
|
|
|
2022-08-16 16:51:26 -07:00
|
|
|
|
def __iter__(self):
|
|
|
|
|
if self.ndim == 0:
|
|
|
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
|
|
|
else:
|
2022-10-08 19:23:32 -07:00
|
|
|
|
assert self.is_fully_replicated or self.is_fully_addressable
|
2022-09-09 20:41:12 -07:00
|
|
|
|
if dispatch.is_single_device_sharding(self.sharding):
|
|
|
|
|
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore
|
|
|
|
|
elif isinstance(self.sharding, PmapSharding):
|
2022-09-09 14:24:39 -07:00
|
|
|
|
return (self[i] for i in range(self.shape[0])) # type: ignore
|
2022-09-01 19:53:58 -07:00
|
|
|
|
else:
|
2022-09-09 20:41:12 -07:00
|
|
|
|
# TODO(yashkatariya): Don't bounce to host and use `_chunk_iter` path
|
2022-10-03 22:28:26 -07:00
|
|
|
|
# here after uneven partitioning support is added.
|
|
|
|
|
return (device_put(self._value[i]) for i in range(self.shape[0]))
|
2022-08-16 16:51:26 -07:00
|
|
|
|
|
|
|
|
|
def item(self):
|
|
|
|
|
if dtypes.issubdtype(self.dtype, np.complexfloating):
|
|
|
|
|
return complex(self)
|
|
|
|
|
elif dtypes.issubdtype(self.dtype, np.floating):
|
|
|
|
|
return float(self)
|
|
|
|
|
elif dtypes.issubdtype(self.dtype, np.integer):
|
|
|
|
|
return int(self)
|
|
|
|
|
elif dtypes.issubdtype(self.dtype, np.bool_):
|
|
|
|
|
return bool(self)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(self.dtype)
|
|
|
|
|
|
2022-10-08 19:23:32 -07:00
|
|
|
|
@property
|
2022-08-24 20:41:48 -07:00
|
|
|
|
def is_fully_replicated(self) -> bool:
|
|
|
|
|
return self.shape == self._arrays[0].shape
|
|
|
|
|
|
2022-06-17 13:11:52 -07:00
|
|
|
|
def __repr__(self):
|
2022-09-28 08:57:07 -07:00
|
|
|
|
prefix = 'Array('
|
2022-08-17 12:25:14 -07:00
|
|
|
|
if self.aval is not None and self.aval.weak_type:
|
|
|
|
|
dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
|
|
|
|
|
else:
|
|
|
|
|
dtype_str = f'dtype={self.dtype.name})'
|
2022-06-17 13:11:52 -07:00
|
|
|
|
|
2022-10-08 19:23:32 -07:00
|
|
|
|
if self.is_fully_addressable or self.is_fully_replicated:
|
2022-06-17 13:11:52 -07:00
|
|
|
|
line_width = np.get_printoptions()["linewidth"]
|
|
|
|
|
s = np.array2string(self._value, prefix=prefix, suffix=',',
|
|
|
|
|
separator=', ', max_line_width=line_width)
|
|
|
|
|
last_line_len = len(s) - s.rfind('\n') + 1
|
|
|
|
|
sep = ' '
|
|
|
|
|
if last_line_len + len(dtype_str) + 1 > line_width:
|
|
|
|
|
sep = ' ' * len(prefix)
|
|
|
|
|
return f"{prefix}{s},{sep}{dtype_str}"
|
|
|
|
|
else:
|
2022-08-23 19:48:59 -07:00
|
|
|
|
return f"{prefix}{self.shape}, {dtype_str}"
|
2022-06-17 13:11:52 -07:00
|
|
|
|
|
2022-10-08 19:23:32 -07:00
|
|
|
|
@pxla.maybe_cached_property
|
2022-06-06 17:31:20 -07:00
|
|
|
|
def is_fully_addressable(self) -> bool:
|
2022-10-08 19:23:32 -07:00
|
|
|
|
return self.sharding.is_fully_addressable
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-08-24 18:27:40 -07:00
|
|
|
|
def __array__(self, dtype=None, context=None):
|
2022-06-13 18:07:55 -07:00
|
|
|
|
return np.asarray(self._value, dtype=dtype)
|
|
|
|
|
|
2022-08-18 15:58:40 -07:00
|
|
|
|
def __dlpack__(self):
|
|
|
|
|
from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
|
|
|
|
|
return to_dlpack(self)
|
|
|
|
|
|
|
|
|
|
def __reduce__(self):
|
2022-08-29 09:00:03 -07:00
|
|
|
|
fun, args, arr_state = self._value.__reduce__() # type: ignore
|
2022-08-18 15:58:40 -07:00
|
|
|
|
aval_state = {'weak_type': self.aval.weak_type,
|
|
|
|
|
'named_shape': self.aval.named_shape}
|
|
|
|
|
return (_reconstruct_array, (fun, args, arr_state, aval_state))
|
|
|
|
|
|
2022-08-29 22:02:32 -07:00
|
|
|
|
def unsafe_buffer_pointer(self):
|
|
|
|
|
assert len(self._arrays) == 1
|
|
|
|
|
return self._arrays[0].unsafe_buffer_pointer()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def __cuda_array_interface__(self):
|
|
|
|
|
assert len(self._arrays) == 1
|
|
|
|
|
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
|
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
# TODO(yashkatariya): Remove this method when everyone is using devices().
|
|
|
|
|
def device(self) -> Device:
|
2022-08-18 15:58:40 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-08-17 12:25:14 -07:00
|
|
|
|
device_set = self.sharding.device_set
|
|
|
|
|
if len(device_set) == 1:
|
|
|
|
|
single_device, = device_set
|
|
|
|
|
return single_device
|
|
|
|
|
raise ValueError('Length of devices is greater than 1. '
|
|
|
|
|
'Please use `.devices()`.')
|
|
|
|
|
|
|
|
|
|
def devices(self) -> List[Device]:
|
2022-08-18 15:58:40 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-08-17 12:25:14 -07:00
|
|
|
|
return list(self.sharding.device_set)
|
|
|
|
|
|
2022-09-15 13:26:57 -07:00
|
|
|
|
# TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
|
|
|
|
|
# deleted.
|
|
|
|
|
@property
|
2022-10-10 14:44:28 -07:00
|
|
|
|
def device_buffer(self) -> ArrayImpl:
|
2022-09-15 13:26:57 -07:00
|
|
|
|
self._check_if_deleted()
|
|
|
|
|
if len(self._arrays) == 1:
|
2022-09-23 13:29:47 -07:00
|
|
|
|
return _single_device_array_from_buf(self._arrays[0], self._committed)
|
2022-09-15 13:26:57 -07:00
|
|
|
|
raise ValueError('Length of buffers is greater than 1. Please use '
|
|
|
|
|
'`.device_buffers` instead.')
|
|
|
|
|
|
|
|
|
|
# TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
|
|
|
|
|
# deleted.
|
|
|
|
|
@property
|
2022-10-10 14:44:28 -07:00
|
|
|
|
def device_buffers(self) -> Sequence[ArrayImpl]:
|
2022-09-15 13:26:57 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-09-23 13:29:47 -07:00
|
|
|
|
return [_single_device_array_from_buf(a, self._committed)
|
|
|
|
|
for a in self._arrays]
|
2022-09-15 13:26:57 -07:00
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
def addressable_data(self, index: int) -> ArrayImpl:
|
2022-09-21 18:18:57 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-09-23 13:29:47 -07:00
|
|
|
|
return _single_device_array_from_buf(self._arrays[index], self._committed)
|
2022-09-21 18:18:57 -07:00
|
|
|
|
|
2022-06-06 18:44:45 -07:00
|
|
|
|
@pxla.maybe_cached_property
|
|
|
|
|
def addressable_shards(self) -> Sequence[Shard]:
|
2022-06-14 11:23:07 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-06-06 18:44:45 -07:00
|
|
|
|
out = []
|
|
|
|
|
for db in self._arrays:
|
|
|
|
|
# Wrap the device arrays in `Array` until C++ returns an Array instead
|
|
|
|
|
# of a DA.
|
2022-09-23 13:29:47 -07:00
|
|
|
|
array = _single_device_array_from_buf(db, self._committed)
|
|
|
|
|
out.append(Shard(db.device(), self.sharding, self.shape, array))
|
2022-06-06 18:44:45 -07:00
|
|
|
|
return out
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
def delete(self):
|
|
|
|
|
if self._arrays is None:
|
|
|
|
|
return
|
|
|
|
|
for buf in self._arrays:
|
|
|
|
|
buf.delete()
|
|
|
|
|
self._arrays = None
|
|
|
|
|
self._npy_value = None
|
|
|
|
|
|
2022-08-18 15:58:40 -07:00
|
|
|
|
def is_deleted(self):
|
2022-09-08 14:39:12 -07:00
|
|
|
|
if self._arrays is None:
|
|
|
|
|
return True
|
|
|
|
|
# This path is taken when a view of `Array` is created and the original
|
|
|
|
|
# Array is deleted. In that case, the buffers the view represents also get
|
|
|
|
|
# deleted.
|
|
|
|
|
return any(buf.is_deleted() for buf in self._arrays)
|
2022-08-18 15:58:40 -07:00
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
def _check_if_deleted(self):
|
|
|
|
|
if self._arrays is None:
|
2022-08-17 12:25:14 -07:00
|
|
|
|
raise RuntimeError("Array has been deleted.")
|
2022-06-13 18:07:55 -07:00
|
|
|
|
|
2022-09-13 16:18:31 -07:00
|
|
|
|
@pxla.use_cpp_method
|
2022-06-13 18:07:55 -07:00
|
|
|
|
def block_until_ready(self):
|
|
|
|
|
self._check_if_deleted()
|
|
|
|
|
for db in self._arrays:
|
|
|
|
|
db.block_until_ready()
|
|
|
|
|
return self
|
|
|
|
|
|
2022-06-06 17:31:20 -07:00
|
|
|
|
def copy_to_host_async(self):
|
2022-06-13 18:07:55 -07:00
|
|
|
|
self._check_if_deleted()
|
|
|
|
|
if self._npy_value is None:
|
2022-06-14 10:34:19 -07:00
|
|
|
|
try:
|
|
|
|
|
self.addressable_shards[0].replica_id
|
|
|
|
|
replica_id_exists = True
|
|
|
|
|
except ValueError:
|
|
|
|
|
replica_id_exists = False
|
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
for s in self.addressable_shards:
|
2022-06-14 11:23:07 -07:00
|
|
|
|
if not replica_id_exists or s.replica_id == 0:
|
2022-06-14 10:34:19 -07:00
|
|
|
|
s.data._arrays[0].copy_to_host_async() # pytype: disable=attribute-error
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
@property
|
2022-06-06 17:31:20 -07:00
|
|
|
|
def _value(self) -> np.ndarray:
|
2022-06-13 18:07:55 -07:00
|
|
|
|
self._check_if_deleted()
|
2022-08-23 19:48:59 -07:00
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
if self._npy_value is None:
|
2022-10-08 19:23:32 -07:00
|
|
|
|
if self.is_fully_replicated:
|
2022-08-24 20:41:48 -07:00
|
|
|
|
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
|
2022-09-15 13:26:57 -07:00
|
|
|
|
self._npy_value.flags.writeable = False
|
2022-08-24 20:41:48 -07:00
|
|
|
|
return cast(np.ndarray, self._npy_value)
|
2022-08-23 19:48:59 -07:00
|
|
|
|
|
2022-10-08 19:23:32 -07:00
|
|
|
|
if not self.is_fully_addressable:
|
2022-08-23 19:48:59 -07:00
|
|
|
|
raise RuntimeError("Fetching value for `jax.Array` that spans "
|
|
|
|
|
"non-addressable devices is not possible. You can use "
|
|
|
|
|
"`jax.experimental.multihost_utils.process_allgather` "
|
|
|
|
|
"for this use case.")
|
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
self.copy_to_host_async()
|
|
|
|
|
npy_value = np.empty(self.shape, self.dtype)
|
2022-06-14 10:34:19 -07:00
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
self.addressable_shards[0].replica_id
|
|
|
|
|
replica_id_exists = True
|
|
|
|
|
except ValueError:
|
|
|
|
|
replica_id_exists = False
|
|
|
|
|
|
2022-06-13 18:07:55 -07:00
|
|
|
|
for s in self.addressable_shards:
|
2022-06-14 11:23:07 -07:00
|
|
|
|
if not replica_id_exists or s.replica_id == 0:
|
2022-08-25 07:27:54 -07:00
|
|
|
|
npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr]
|
2022-06-13 18:07:55 -07:00
|
|
|
|
self._npy_value = npy_value # type: ignore
|
2022-09-15 13:26:57 -07:00
|
|
|
|
self._npy_value.flags.writeable = False
|
2022-06-13 18:07:55 -07:00
|
|
|
|
# https://docs.python.org/3/library/typing.html#typing.cast
|
|
|
|
|
return cast(np.ndarray, self._npy_value)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
# explicitly set to be unhashable. Same as what device_array.py does.
|
2022-09-26 16:17:26 -07:00
|
|
|
|
setattr(ArrayImpl, "__hash__", None)
|
|
|
|
|
setattr(ArrayImpl, "__array_priority__", 100)
|
2022-08-31 15:06:58 -07:00
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
def make_array_from_callback(
|
|
|
|
|
shape: Shape, sharding: Sharding,
|
|
|
|
|
data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
|
2022-08-30 21:56:39 -07:00
|
|
|
|
device_to_index_map = sharding.devices_indices_map(shape)
|
2022-08-31 15:06:58 -07:00
|
|
|
|
# Use addressable_devices here instead of `_addressable_device_assignment`
|
|
|
|
|
# because `_addressable_device_assignment` is only available on
|
|
|
|
|
# `XLACompatibleSharding` and this function is supposed to work for every
|
|
|
|
|
# `Sharding`.
|
2022-06-22 09:20:26 -07:00
|
|
|
|
arrays = [
|
2022-08-30 21:56:39 -07:00
|
|
|
|
device_put(data_callback(device_to_index_map[device]), device)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
for device in sharding.addressable_devices
|
|
|
|
|
]
|
2022-08-17 12:25:14 -07:00
|
|
|
|
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return ArrayImpl(aval, sharding, arrays, committed=True)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
def make_array_from_single_device_arrays(
|
|
|
|
|
shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl:
|
2022-09-26 12:43:13 -07:00
|
|
|
|
# All input arrays should be committed. Checking it is expensive on
|
|
|
|
|
# single-controller systems.
|
|
|
|
|
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return ArrayImpl(aval, sharding, arrays, committed=True)
|
2022-09-26 12:43:13 -07:00
|
|
|
|
|
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
|
|
|
|
|
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
|
|
|
|
|
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
|
|
|
|
|
api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
|
|
|
|
|
ad_util.jaxval_adders[ArrayImpl] = lax_internal.add
|
|
|
|
|
ad_util.jaxval_zeros_likers[ArrayImpl] = lax_internal.zeros_like_array
|
|
|
|
|
if xc._version >= 96:
|
2022-09-23 09:59:46 -07:00
|
|
|
|
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
2022-09-26 16:17:26 -07:00
|
|
|
|
basearray.Array.register(ArrayImpl)
|
2022-08-12 12:09:22 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _array_mlir_constant_handler(val, canonicalize_types=True):
|
|
|
|
|
return mlir.ir_constants(val._value,
|
|
|
|
|
canonicalize_types=canonicalize_types)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
2022-06-24 10:04:31 -07:00
|
|
|
|
|
|
|
|
|
def _device_put_array(x, device: Optional[Device]):
|
2022-09-08 08:49:12 -07:00
|
|
|
|
if dispatch.is_single_device_sharding(x.sharding):
|
2022-06-28 12:48:39 -07:00
|
|
|
|
x = dispatch._copy_device_array_to_device(pxla._set_aval(x._arrays[0]), device)
|
|
|
|
|
return (x,)
|
|
|
|
|
else:
|
|
|
|
|
# Round trip via host if x is sharded. SDA also does a round trip via host.
|
|
|
|
|
return dispatch._device_put_array(x._value, device)
|
|
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
|
dispatch.device_put_handlers[ArrayImpl] = _device_put_array
|
2022-06-24 10:04:31 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-19 21:36:43 -07:00
|
|
|
|
def _array_pmap_shard_arg(x, devices, indices, mode):
|
2022-09-08 08:49:12 -07:00
|
|
|
|
if dispatch.is_single_device_sharding(x.sharding):
|
2022-08-19 21:36:43 -07:00
|
|
|
|
return pxla._shard_device_array(x, devices, indices, mode)
|
|
|
|
|
|
|
|
|
|
# If the sharding of Array does not match pmap's sharding then take the slow
|
|
|
|
|
# path which is similar to what SDA does. This slow path reroute only happens
|
|
|
|
|
# for `pmap`.
|
2022-09-19 16:58:46 -07:00
|
|
|
|
x_indices = tuple(x.sharding.addressable_devices_indices_map(x.shape).values())
|
2022-08-31 15:06:58 -07:00
|
|
|
|
if indices == x_indices:
|
2022-08-10 20:11:06 -07:00
|
|
|
|
return [buf if buf.device() == d else buf.copy_to_device(d)
|
|
|
|
|
for buf, d in safe_zip(x._arrays, devices)]
|
2022-08-19 21:36:43 -07:00
|
|
|
|
else:
|
|
|
|
|
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
|
|
|
|
|
|
|
|
|
|
|
2022-10-07 16:48:34 -07:00
|
|
|
|
def _array_rest_shard_arg(x: ArrayImpl, devices, indices, mode):
|
|
|
|
|
x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
|
2022-10-08 19:23:32 -07:00
|
|
|
|
if not x.is_fully_addressable:
|
2022-10-08 11:39:05 -07:00
|
|
|
|
if tuple(x_indices) == tuple(indices):
|
|
|
|
|
return x._arrays
|
|
|
|
|
else:
|
|
|
|
|
return NotImplementedError("Cannot reshard an input that is not fully "
|
|
|
|
|
"addressable")
|
2022-09-15 10:33:31 -07:00
|
|
|
|
else:
|
2022-10-08 11:39:05 -07:00
|
|
|
|
if tuple(x_indices) == tuple(indices):
|
|
|
|
|
return [buf if buf.device() == d else buf.copy_to_device(d)
|
|
|
|
|
for buf, d in safe_zip(x._arrays, devices)]
|
|
|
|
|
# Resharding starts here:
|
|
|
|
|
if isinstance(x.sharding, PmapSharding):
|
|
|
|
|
return pxla.device_put(x._value, devices, replicate=True)
|
2022-09-15 10:33:31 -07:00
|
|
|
|
if dispatch.is_single_device_sharding(x.sharding):
|
|
|
|
|
return pxla._shard_device_array(x, devices, indices, mode)
|
|
|
|
|
else:
|
2022-10-08 11:39:05 -07:00
|
|
|
|
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
|
2022-09-15 10:33:31 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-19 21:36:43 -07:00
|
|
|
|
def _array_shard_arg(x, devices, indices, mode):
|
|
|
|
|
if mode == pxla.InputsHandlerMode.pmap:
|
|
|
|
|
return _array_pmap_shard_arg(x, devices, indices, mode)
|
2022-08-10 20:11:06 -07:00
|
|
|
|
else:
|
2022-09-15 10:33:31 -07:00
|
|
|
|
return _array_rest_shard_arg(x, devices, indices, mode)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-31 22:53:32 -07:00
|
|
|
|
def _array_global_result_handler(global_aval, out_sharding, committed,
|
|
|
|
|
is_out_sharding_from_xla):
|
2022-08-29 22:02:32 -07:00
|
|
|
|
if global_aval.dtype == dtypes.float0:
|
|
|
|
|
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
|
2022-08-30 14:47:15 -07:00
|
|
|
|
if core.is_opaque_dtype(global_aval.dtype):
|
2022-08-30 13:25:49 -07:00
|
|
|
|
return global_aval.dtype._rules.global_sharded_result_handler(
|
2022-08-31 22:53:32 -07:00
|
|
|
|
global_aval, out_sharding, committed, is_out_sharding_from_xla)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
|
|
|
|
|
committed=committed, _skip_checks=True)
|
2022-08-10 20:11:06 -07:00
|
|
|
|
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
|
|
|
|
|
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
|
2022-08-29 22:02:32 -07:00
|
|
|
|
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
|
2022-08-10 20:11:06 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-31 15:06:58 -07:00
|
|
|
|
# Only used for Arrays that come out of pmap.
|
2022-08-10 20:11:06 -07:00
|
|
|
|
def _array_local_result_handler(aval, sharding, indices):
|
2022-08-30 14:47:15 -07:00
|
|
|
|
if core.is_opaque_dtype(aval.dtype):
|
2022-08-30 13:25:49 -07:00
|
|
|
|
return aval.dtype._rules.local_sharded_result_handler(
|
|
|
|
|
aval, sharding, indices)
|
2022-09-26 16:17:26 -07:00
|
|
|
|
return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
|
|
|
|
|
_skip_checks=True)
|
2022-08-10 20:11:06 -07:00
|
|
|
|
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
|
|
|
|
|
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
|