rocm_jax/jax/_src/array.py
Yash Katariya 021fadfcbc Optimize accessing index and replica_id of
addressable_shards

Benchmark:

```
name                                 old time/op  new time/op  delta
bench_addressable_shards_index       53.0µs ± 2%   2.6µs ± 4%  -95.07%  (p=0.008 n=5+5)
bench_addressable_shards_replica_id  51.7µs ± 2%   2.6µs ± 2%  -94.92%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 517977244
2023-03-20 08:37:09 -07:00

734 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import operator as op
import numpy as np
import functools
from typing import (Sequence, Tuple, Callable, Union, Optional, cast, List, Set,
TYPE_CHECKING)
import jax
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
from jax._src.config import config
from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import api
from jax._src.typing import ArrayLike
from jax.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map)
# Type aliases used in annotations throughout this module.
# Shape: a tuple of integer dimension sizes.
Shape = Tuple[int, ...]
# Device: alias for the `xc.Device` class.
Device = xc.Device
# DeviceArray: alias for `xc.Buffer`; `ArrayImpl._arrays` is annotated as a
# list of these.
DeviceArray = xc.Buffer
# Index: a tuple of slices; `Shard.index` is annotated with this type.
Index = Tuple[slice, ...]
class Shard:
  """One per-device piece of an Array.

  Attributes:
    device : The device on which this shard resides.
    index : The index of this shard within the global array.
    replica_id : Integer id of the replica of the global array that this
      shard belongs to. Always 0 for fully sharded data (i.e. when there is
      only 1 replica).
    data : The data held by this shard. None if ``device`` is non-local.
  """

  def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
               data: Optional[ArrayImpl] = None):
    """Stores the arguments; ``index`` and ``replica_id`` are computed lazily."""
    self._device = device
    self._sharding = sharding
    self._global_shape = global_shape
    self._data = data

  def __repr__(self):
    device_repr = repr(self.device)
    try:
      return (f'Shard(device={device_repr}, index={self.index}, '
              f'replica_id={self.replica_id}, data={self.data})')
    except ValueError:
      # `index` raises ValueError when the sharding cannot produce a device to
      # index mapping; fall back to a shorter representation in that case.
      return f'Shard(device={device_repr}, data={self.data})'

  @functools.cached_property
  def index(self) -> Index:
    """Index of this shard in the global array (computed once, then cached).

    Raises:
      ValueError: if the sharding has no ``devices_indices_map`` attribute.
    """
    try:
      indices_map_fn = self._sharding.devices_indices_map
    except AttributeError:
      raise ValueError('Cannot calculate indices from sharding: '
                       f'{self._sharding}. Please create a device to index '
                       'mapping for your sharding.') from None
    shard_index = indices_map_fn(self._global_shape)[self.device]
    assert shard_index is not None
    return shard_index

  @functools.cached_property
  def replica_id(self) -> int:
    """Replica id of this shard's device (computed once, then cached)."""
    replica_ids = device_replica_id_map(self._sharding, self._global_shape)
    return replica_ids[self.device]

  @property
  def device(self):
    """The device passed to the constructor."""
    return self._device

  @property
  def data(self):
    """The data passed to the constructor (may be None)."""
    return self._data
def _reconstruct_array(fun, args, arr_state, aval_state):
  """Rebuilds a device array from the pieces produced by ``ArrayImpl.__reduce__``.

  Args:
    fun: callable that builds the host array when called with ``*args``.
    args: positional arguments for ``fun``.
    arr_state: state passed to the host array's ``__setstate__``.
    aval_state: keyword arguments used to update the aval of the result.

  Returns:
    The value returned by ``api.device_put`` with its aval updated.
  """
  host_array = fun(*args)
  host_array.__setstate__(arr_state)
  restored = api.device_put(host_array)
  restored.aval = restored.aval.update(**aval_state)
  return restored
def _single_device_array_from_buf(buf, committed) -> ArrayImpl:
  """Wraps a single-device buffer in an ``ArrayImpl``.

  ``buf`` is returned unchanged when it is already an ``ArrayImpl`` whose
  ``_committed`` flag equals ``committed``.
  """
  already_wrapped = isinstance(buf, ArrayImpl) and buf._committed == committed  # type: ignore
  if already_wrapped:
    return buf
  buf_with_aval = dispatch._set_aval(buf)
  aval = buf_with_aval.aval
  device_sharding = SingleDeviceSharding(buf_with_aval.device())
  # Checks are skipped: the sharding is built from the buffer's own device.
  return ArrayImpl(aval, device_sharding, [buf_with_aval],
                   committed=committed, _skip_checks=True)
def _is_reduced_on_dim(idx):
# TODO(yashkatariya): This handles very narrow use case where we know XLA will
# not return an output with uneven sharding. Remove this after we have the
# ability to catch uneven shardings in lower_sharding_computation and raise
# a special exception for that which can be caught here to fallback to
# bouncing via host.
if not isinstance(idx, tuple):
idx = (idx,)
return all(isinstance(i, int) or
(isinstance(i, slice) and i == slice(None)) or
(isinstance(i, (np.ndarray, jax.Array)) and not i.shape and
np.issubdtype(i.dtype, np.integer))
for i in idx)
class ArrayImpl(basearray.Array):
  # Array type defined by this module: it holds a global `aval`, a `Sharding`
  # and a list of per-device buffers (`_arrays`).
  # NOTE(review): methods decorated with `use_cpp_method` presumably get
  # replaced by C++ implementations when `use_cpp_class(xc.ArrayImpl)` is
  # applied to this class after its definition -- confirm in jax._src.util.
  # TODO(yashkatariya): Add __slots__ here.

  aval: core.ShapedArray
  _sharding: Sharding
  _arrays: List[DeviceArray]
  _committed: bool
  _skip_checks: bool
  _npy_value: Optional[np.ndarray]

  @use_cpp_method()
  def __init__(self, aval: core.ShapedArray, sharding: Sharding,
               arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
               committed: bool, _skip_checks: bool = False):
    """Stores aval, sharding and per-device buffers, optionally validating them."""
    # NOTE: the actual implementation of the constructor is moved to C++.

    self.aval = aval
    self._sharding = sharding
    # Extract DeviceArrays from arrays with `SingleDeviceSharding` to keep the
    # code handling `self._arrays` simpler.
    # TODO(yashkatariya): This will be slower as it will happen during
    # `__init__` on single controller environment. Make it lazy.
    buffers = []
    for a in arrays:
      buffers.append(a if isinstance(a, DeviceArray) else a._arrays[0])
    self._arrays = buffers

    # See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
    # for what committed means.
    self._committed = committed
    self._npy_value = None

    # Don't rearrange if skip_checks is enabled because this assumes that the
    # input buffers are already arranged properly. This usually happens when
    # Array's are created as output of a JAX transformation
    # (like pjit, xmap, etc).
    if not _skip_checks or config.jax_enable_checks:
      self._check_and_rearrange()

  def _check_and_rearrange(self):
    """Validates dtype, device set and shard shape of the buffers, then
    reorders them to follow the sharding's addressable device assignment.

    Raises:
      ValueError: on a dtype mismatch, a wrong number of buffers, a device
        mismatch between buffers and sharding, or an unexpected shard shape.
    """
    expected_dtype = self.dtype
    for db in self._arrays:
      if db.dtype != expected_dtype:
        raise ValueError(
            "Input buffers to `Array` must have matching dtypes. "
            f"Got {db.dtype}, expected {expected_dtype} for buffer: {db}")

    device_id_to_buffer = {db.device().id: db for db in self._arrays}

    addressable_dev = self.sharding.addressable_devices
    if len(self._arrays) != len(addressable_dev):
      raise ValueError(
          f"Expected {len(addressable_dev)} per-device arrays "
          "(this is how many devices are addressable by the sharding), but "
          f"got {len(self._arrays)}")

    array_device_ids = set(device_id_to_buffer)
    addressable_device_ids = {d.id for d in addressable_dev}
    # The device ids of the sharding and of `_arrays` should match, so both
    # one-sided differences must be empty (i.e. the symmetric difference is).
    dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
    dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
    if dev_in_sharding_not_in_arrays or dev_in_arrays_not_in_sharding:
      err_msg = (
          "Addressable devices and per-device arrays devices do not match.")
      if dev_in_sharding_not_in_arrays:
        err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
                    "that are not present in per-device arrays.")
      if dev_in_arrays_not_in_sharding:
        err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
                    "that are not present in the sharding.")
      raise ValueError(err_msg)

    ss = self.sharding.shard_shape(self.shape)
    for db in self._arrays:
      if db.shape != ss:
        raise ValueError(
            f"Expected shard shape {ss} doesn't match the buffer "
            f"shape {db.shape} for buffer: {db}")

    # Rearrange arrays based on the device assignment.
    if isinstance(self.sharding, XLACompatibleSharding):
      addressable_da = self.sharding._addressable_device_assignment
      self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]

  @property
  def shape(self) -> Shape:
    """Global shape, read from the aval."""
    return self.aval.shape

  @property
  def dtype(self):
    """Dtype, read from the aval."""
    return self.aval.dtype

  @property
  def ndim(self):
    """Number of dimensions of the global shape."""
    return len(self.shape)

  @property
  def size(self):
    """Product of the dimensions of the global shape."""
    return math.prod(self.shape)

  @property
  def sharding(self):
    """The sharding passed to the constructor."""
    return self._sharding

  @property
  def weak_type(self):
    """``weak_type`` flag of the aval."""
    return self.aval.weak_type

  def __str__(self):
    return str(self._value)

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  def __bool__(self):
    return bool(self._value)

  def __nonzero__(self):
    return bool(self._value)

  def __float__(self):
    return self._value.__float__()

  def __int__(self):
    return self._value.__int__()

  def __complex__(self):
    return self._value.__complex__()

  def __hex__(self):
    assert self.ndim == 0, 'hex only works on scalar values'
    return hex(self._value)  # type: ignore

  def __oct__(self):
    assert self.ndim == 0, 'oct only works on scalar values'
    return oct(self._value)  # type: ignore

  def __index__(self):
    return op.index(self._value)

  def tobytes(self, order="C"):
    """Delegates to ``tobytes`` of the host value."""
    return self._value.tobytes(order)

  def tolist(self):
    """Delegates to ``tolist`` of the host value."""
    return self._value.tolist()

  def __format__(self, format_spec):
    # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    value = self._value[()] if self.ndim == 0 else self._value
    return format(value, format_spec)

  def __getitem__(self, idx):
    from jax._src.numpy import lax_numpy
    self._check_if_deleted()

    if isinstance(self.sharding, PmapSharding):
      # Pad the index with full slices up to the rank of the array.
      idx_tuple = idx if isinstance(idx, tuple) else (idx,)
      cidx = idx_tuple + (slice(None),) * (len(self.shape) - len(idx_tuple))
      if self._npy_value is None:
        indices = tuple(self.sharding.devices_indices_map(self.shape).values())
        try:
          arr_idx = indices.index(cidx)
        except ValueError:
          pass
        else:
          # The padded index is exactly the index of one shard: return that
          # shard's buffer directly as an uncommitted single-device array.
          return _single_device_array_from_buf(self._arrays[arr_idx],
                                               committed=False)
      return lax_numpy._rewriting_take(self, idx)

    if (dispatch.is_single_device_sharding(self.sharding) or
        self.is_fully_replicated or _is_reduced_on_dim(idx)):
      return lax_numpy._rewriting_take(self, idx)

    # TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or
    # the fast path (see PmapSharding branch above) after uneven
    # partitioning support is added
    return api.device_put(self._value[idx])

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error

    assert self.is_fully_replicated or self.is_fully_addressable
    if dispatch.is_single_device_sharding(self.sharding) or self.is_fully_replicated:
      return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())  # type: ignore
    if isinstance(self.sharding, PmapSharding):
      return (self[i] for i in range(self.shape[0]))  # type: ignore
    # TODO(yashkatariya): Don't bounce to host and use `_chunk_iter` path
    # here after uneven partitioning support is added.
    return (api.device_put(self._value[i]) for i in range(self.shape[0]))

  @property
  def is_fully_replicated(self) -> bool:
    """True when the first per-device buffer has the global shape."""
    return self.shape == self._arrays[0].shape

  def __repr__(self):
    prefix = 'Array('
    if self.aval is not None and self.aval.weak_type:
      dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
    else:
      dtype_str = f'dtype={self.dtype.name})'

    if not (self.is_fully_addressable or self.is_fully_replicated):
      # The value cannot be fetched here, so only shape and dtype are shown.
      return f"{prefix}{self.shape}, {dtype_str}"

    line_width = np.get_printoptions()["linewidth"]
    s = np.array2string(self._value, prefix=prefix, suffix=',',
                        separator=', ', max_line_width=line_width)
    last_line_len = len(s) - s.rfind('\n') + 1
    sep = ' '
    if last_line_len + len(dtype_str) + 1 > line_width:
      sep = ' ' * len(prefix)
    return f"{prefix}{s},{sep}{dtype_str}"

  @functools.cached_property
  def is_fully_addressable(self) -> bool:
    """Cached value of ``sharding.is_fully_addressable``."""
    return self.sharding.is_fully_addressable

  def __array__(self, dtype=None, context=None):
    return np.asarray(self._value, dtype=dtype)

  def __dlpack__(self):
    from jax.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
    return to_dlpack(self)

  def __reduce__(self):
    # Reuse the host array's reduce tuple and add the aval fields that
    # `_reconstruct_array` restores.
    fun, args, arr_state = self._value.__reduce__()  # type: ignore
    aval_state = {'weak_type': self.aval.weak_type,
                  'named_shape': self.aval.named_shape}
    return (_reconstruct_array, (fun, args, arr_state, aval_state))

  @use_cpp_method(xla_extension_version >= 138)
  def unsafe_buffer_pointer(self):
    """Buffer pointer of the only buffer; ValueError if there is not exactly one."""
    if len(self._arrays) != 1:
      raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
                       " arrays.")
    return self._arrays[0].unsafe_buffer_pointer()

  @property
  @use_cpp_method(xla_extension_version >= 138)
  def __cuda_array_interface__(self):
    if len(self._arrays) != 1:
      raise ValueError("__cuda_array_interface__() is supported only for "
                       "unsharded arrays.")
    return self._arrays[0].__cuda_array_interface__  # pytype: disable=attribute-error  # bind-properties

  @use_cpp_method(xla_extension_version >= 138)
  def on_device_size_in_bytes(self):
    """Returns the total global on-device size of the array in bytes."""
    # Size of the first buffer multiplied by the number of devices in the
    # sharding's device set.
    first_buffer = self._arrays[0]
    per_shard_size = first_buffer.on_device_size_in_bytes()  # type: ignore
    return per_shard_size * len(self.sharding.device_set)

  # TODO(yashkatariya): Remove this method when everyone is using devices().
  def device(self) -> Device:
    """The single device of the sharding; ValueError if there are several."""
    self._check_if_deleted()
    device_set = self.sharding.device_set
    if len(device_set) != 1:
      raise ValueError('Length of devices is greater than 1. '
                       'Please use `.devices()`.')
    (single_device,) = device_set
    return single_device

  def devices(self) -> Set[Device]:
    """The device set of the sharding."""
    self._check_if_deleted()
    return self.sharding.device_set

  # TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
  # deleted.
  @property
  def device_buffer(self) -> ArrayImpl:
    self._check_if_deleted()
    if len(self._arrays) != 1:
      raise ValueError('Length of buffers is greater than 1. Please use '
                       '`.device_buffers` instead.')
    return _single_device_array_from_buf(self._arrays[0], self._committed)

  # TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
  # deleted.
  @property
  def device_buffers(self) -> Sequence[ArrayImpl]:
    self._check_if_deleted()
    committed = self._committed
    return [_single_device_array_from_buf(a, committed) for a in self._arrays]

  def addressable_data(self, index: int) -> ArrayImpl:
    """The ``index``-th per-device buffer wrapped as a single-device array."""
    self._check_if_deleted()
    return _single_device_array_from_buf(self._arrays[index], self._committed)

  @functools.cached_property
  def addressable_shards(self) -> Sequence[Shard]:
    """One ``Shard`` per entry of ``_arrays`` (computed once, then cached)."""
    self._check_if_deleted()
    shards = []
    for db in self._arrays:
      # Wrap the device arrays in `Array` until C++ returns an Array instead
      # of a DA.
      wrapped = _single_device_array_from_buf(db, self._committed)
      shards.append(Shard(db.device(), self.sharding, self.shape, wrapped))
    return shards

  @property
  def global_shards(self) -> Sequence[Shard]:
    """Returns list of all `Shard`s of the Array across all devices.

    The result includes shards that are not addressable by the current process.
    If a `Shard` is not addressable, then its `data` will be `None`.
    """
    self._check_if_deleted()
    if self.is_fully_addressable:  # pylint: disable=using-constant-test
      return self.addressable_shards

    device_id_to_buffer = {db.device().id: db for db in self._arrays}
    shards = []
    for global_d in self.sharding.device_set:
      local_buf = device_id_to_buffer.get(global_d.id)
      if local_buf is None:
        wrapped = None
      else:
        wrapped = _single_device_array_from_buf(local_buf, self._committed)
      shards.append(Shard(global_d, self.sharding, self.shape, wrapped))
    return shards

  @use_cpp_method(xla_extension_version >= 138)
  def delete(self):
    """Deletes every buffer and clears the cached host value; no-op if already deleted."""
    if self._arrays is None:
      return
    for buf in self._arrays:
      buf.delete()
    self._arrays = None
    self._npy_value = None

  @use_cpp_method()
  def is_deleted(self):
    """True if ``delete`` was called or any underlying buffer is deleted."""
    if self._arrays is None:
      return True
    # This path is taken when a view of `Array` is created and the original
    # Array is deleted. In that case, the buffers the view represents also get
    # deleted.
    return any(buf.is_deleted() for buf in self._arrays)

  def _check_if_deleted(self):
    """Raises RuntimeError if the array has been deleted."""
    if self.is_deleted():
      raise RuntimeError("Array has been deleted.")

  @use_cpp_method()
  def block_until_ready(self):
    """Calls ``block_until_ready`` on every buffer and returns ``self``."""
    self._check_if_deleted()
    for db in self._arrays:
      db.block_until_ready()
    return self

  @use_cpp_method(xla_extension_version >= 138)
  def _single_device_array_to_np_array(self):
    return np.asarray(self._arrays[0])

  @use_cpp_method(xla_extension_version >= 138)
  def _copy_single_device_array_to_host_async(self):
    self._arrays[0].copy_to_host_async()

  @profiler.annotate_function
  def copy_to_host_async(self):
    """Starts host copies of the replica-0 shards unless a host value is cached."""
    self._check_if_deleted()
    if self._npy_value is not None:
      return
    if self.is_fully_replicated:
      self._copy_single_device_array_to_host_async()
      return
    # Only calculate the device_to_replica_id map once for performance
    replica_id_of = device_replica_id_map(self.sharding, self.shape)
    for arr in self._arrays:
      if replica_id_of[arr.device()] != 0:
        continue
      if isinstance(arr, ArrayImpl):
        arr._copy_single_device_array_to_host_async()
      else:
        arr.copy_to_host_async()

  @property
  @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
  def _value(self) -> np.ndarray:
    """Read-only host copy of the array, built on first access and cached.

    Raises:
      RuntimeError: if the array is deleted, or is neither fully replicated
        nor fully addressable.
    """
    self._check_if_deleted()

    if self._npy_value is not None:
      # https://docs.python.org/3/library/typing.html#typing.cast
      return cast(np.ndarray, self._npy_value)

    if self.is_fully_replicated:
      self._npy_value = self._single_device_array_to_np_array()  # type: ignore
      self._npy_value.flags.writeable = False
      return cast(np.ndarray, self._npy_value)

    if not self.is_fully_addressable:
      raise RuntimeError("Fetching value for `jax.Array` that spans "
                         "non-addressable devices is not possible. You can use "
                         "`jax.experimental.multihost_utils.process_allgather` "
                         "for this use case.")

    # Only calculate the device_to_replica_id map once for performance
    replica_id_of = device_replica_id_map(self.sharding, self.shape)
    # device() is slow so compute it only once for the rest of the function.
    devices = [arr.device() for arr in self._arrays]
    # Only shards whose device has replica id 0 are copied to the host.
    primary = [(arr, d) for arr, d in zip(self._arrays, devices)
               if replica_id_of[d] == 0]

    # Start all host copies before reading any of them.
    for arr, _ in primary:
      if isinstance(arr, ArrayImpl):
        arr._copy_single_device_array_to_host_async()
      else:
        arr.copy_to_host_async()

    # Only calculate the device_to_index map once for performance
    index_of = self.sharding.devices_indices_map(self.shape)
    npy_value = np.empty(self.shape, self.dtype)
    for arr, d in primary:
      if isinstance(arr, ArrayImpl):
        npy_value[index_of[d]] = arr._single_device_array_to_np_array()
      else:
        npy_value[index_of[d]] = np.asarray(arr)
    self._npy_value = npy_value  # type: ignore
    self._npy_value.flags.writeable = False
    # https://docs.python.org/3/library/typing.html#typing.cast
    return cast(np.ndarray, self._npy_value)
# TODO(b/273265390): ideally we would write this as a decorator on the ArrayImpl
# class, however this triggers a pytype bug. Workaround: apply the decorator
# after the fact.
# The rebinding below is skipped when TYPE_CHECKING is true, so in that case
# the name `ArrayImpl` keeps referring to the pure-Python class.
if not TYPE_CHECKING:
  ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)
# explicitly set to be unhashable. Same as what device_array.py does.
setattr(ArrayImpl, "__hash__", None)
# NOTE(review): presumably the `__array_priority__` value makes NumPy defer to
# ArrayImpl in mixed operations -- confirm against the NumPy documentation.
setattr(ArrayImpl, "__array_priority__", 100)
def make_array_from_callback(
    shape: Shape, sharding: Sharding,
    data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
  """Returns a ``jax.Array`` via data fetched from ``data_callback``.

  ``data_callback`` is called once per addressable device of ``sharding`` to
  fetch the data of the corresponding shard of the returned ``jax.Array``.

  Args:
    shape : Shape of the ``jax.Array``.
    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
      laid out across devices.
    data_callback : Callback that takes indices into the global array value as
      input and returns the corresponding data of the global array value.
      The data can be returned as any array-like object, e.g. a ``numpy.ndarray``.

  Returns:
    A ``jax.Array`` via data fetched from ``data_callback``.

  Example:

    >>> import math
    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> import numpy as np
    ...
    >>> input_shape = (8, 8)
    >>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
    ...
    >>> def cb(index):
    ...  return global_input_data[index]
    ...
    >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
    >>> arr.addressable_data(0).shape
    (4, 2)
  """
  index_for_device = sharding.devices_indices_map(shape)
  # Use addressable_devices here instead of `_addressable_device_assignment`
  # because `_addressable_device_assignment` is only available on
  # `XLACompatibleSharding` and this function is supposed to work for every
  # `Sharding`.
  per_device_arrays = []
  for device in sharding.addressable_devices:
    shard_data = data_callback(index_for_device[device])
    per_device_arrays.append(api.device_put(shard_data, device))
  # The dtype of the result is taken from the first per-device array.
  aval = core.ShapedArray(shape, per_device_arrays[0].dtype, weak_type=False)
  return ArrayImpl(aval, sharding, per_device_arrays, committed=True)
def make_array_from_single_device_arrays(
    shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
) -> ArrayImpl:
  r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.

  ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can use
  this function if you have already ``jax.device_put`` the value on a single
  device and want to create a global Array. The smaller ``jax.Array``\s should be
  addressable and belong to the current process.

  Args:
    shape : Shape of the ``jax.Array``.
    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
      laid out across devices.
    arrays: Sequence of ``jax.Array``\s that are on a single device.

  Returns:
    A ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.

  Example:

    >>> import math
    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> import numpy as np
    ...
    >>> shape = (8, 8)
    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
    >>> inp_data = np.arange(math.prod(shape)).reshape(shape)
    ...
    >>> arrays = [
    ...    jax.device_put(inp_data[index], d)
    ...        for d, index in sharding.addressable_devices_indices_map(shape).items()]
    ...
    >>> arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
    >>> arr.addressable_data(0).shape
    (4, 2)
  """
  # All input arrays should be committed. Checking it is expensive on
  # single-controller systems.
  # The dtype of the result is taken from the first input array.
  first_dtype = arrays[0].dtype
  aval = core.ShapedArray(shape, first_dtype, weak_type=False)
  # TODO(phawkins): ideally the cast() could be checked. Revisit this after
  # removing DeviceArray.
  typed_arrays = cast(Sequence[ArrayImpl], arrays)
  return ArrayImpl(aval, sharding, typed_arrays, committed=True)
# Register `ArrayImpl` in the per-type tables of `core`, `xla` and `api_util`.
# The two entries built with `op.attrgetter('aval')` read the array's own
# `aval` attribute.
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(ArrayImpl)
def _array_mlir_constant_handler(val, canonicalize_types=True):
  """Builds MLIR constants from the host value (``val._value``) of ``val``."""
  host_value = val._value
  return mlir.ir_constants(host_value, canonicalize_types=canonicalize_types)
# Register `_array_mlir_constant_handler` as the MLIR constant handler for
# `ArrayImpl`.
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
def _array_shard_arg(x, devices, indices, sharding):
x._check_if_deleted()
x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
if not x.is_fully_addressable:
if tuple(x_indices) == tuple(indices):
return x
else:
raise NotImplementedError(
"Cannot reshard an input that is not fully addressable")
else:
if tuple(x_indices) == tuple(indices):
return xc.copy_array_to_devices_with_sharding(
x, list(devices), sharding)
# Resharding starts here:
if dispatch.is_single_device_sharding(x.sharding):
return pxla.shard_device_array(x, devices, indices, sharding)
else:
return pxla.shard_sharded_device_array_slow_path(
x, devices, indices, sharding)
# Register `_array_shard_arg` as the shard-arg handler for `ArrayImpl`.
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
return xc.array_result_handler(
global_aval, out_sharding, committed=committed, _skip_checks=True
)
# Register the global result handlers for `OutputType.Array`, keyed by
# (aval type, output type). Shaped and concrete arrays share one handler.
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
# For tokens, the handler factory ignores its arguments and the handler it
# returns always yields `core.token`.
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
# Only used for Arrays that come out of pmap.
def _array_local_result_handler(aval, sharding, indices):
if aval.dtype == dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
return xc.array_result_handler(
aval, sharding, committed=True, _skip_checks=True
)
# Register the local result handlers for `OutputType.Array`, keyed by
# (aval type, output type). Shaped and concrete arrays share one handler.
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler