2022-06-06 17:31:20 -07:00
|
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
import operator as op
|
2022-06-06 17:31:20 -07:00
|
|
|
|
import numpy as np
|
2022-06-22 09:20:26 -07:00
|
|
|
|
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-06-10 07:31:43 -07:00
|
|
|
|
from jax import core
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax._src import ad_util
|
2022-07-27 10:54:54 -07:00
|
|
|
|
from jax._src import api_util
|
2022-06-24 10:04:31 -07:00
|
|
|
|
from jax._src import dispatch
|
2022-08-16 16:51:26 -07:00
|
|
|
|
from jax._src import dtypes
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax._src.lax import lax as lax_internal
|
2022-06-06 17:31:20 -07:00
|
|
|
|
from jax._src.config import config
|
2022-08-10 20:11:06 -07:00
|
|
|
|
from jax._src.util import prod, safe_zip
|
2022-06-06 17:31:20 -07:00
|
|
|
|
from jax._src.lib import xla_client as xc
|
|
|
|
|
from jax._src.api import device_put
|
2022-08-12 12:09:22 -07:00
|
|
|
|
from jax.interpreters import pxla, xla, mlir
|
2022-06-10 07:31:43 -07:00
|
|
|
|
from jax.experimental.sharding import (Sharding, SingleDeviceSharding,
|
2022-06-22 09:20:26 -07:00
|
|
|
|
XLACompatibleSharding)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
|
|
|
|
# Static shape of an array: one int per dimension (`Array.ndim` is its length).
Shape = Tuple[int, ...]
# Device handle from the XLA client bindings.
Device = xc.Device
# A single per-device buffer from the XLA client bindings.
DeviceArray = xc.Buffer
# Index of a shard into the global array: a tuple of slices (presumably one
# per dimension -- TODO confirm against `Sharding.device_indices`).
Index = Tuple[slice, ...]
# Data a `make_array_from_callback` callback may return: a host numpy array
# or a device buffer.
ArrayLike = Union[np.ndarray, DeviceArray]
|
|
|
|
|
|
|
|
|
|
|
2022-06-06 18:44:45 -07:00
|
|
|
|
class Shard:
  """One device's slice of an `Array`.

  `index` and `replica_id` are computed on demand from the sharding and the
  global shape; each raises `ValueError` when the sharding does not provide
  the device mapping it needs.

  Attributes:
    device : The device this shard lives on.
    index : Position of this shard within the global array.
    replica_id : Which replica of the global array this shard belongs to;
      always 0 for fully sharded data (a single replica).
    data : The shard's contents, or None when ``device`` is not local.
  """

  def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
               data: Optional[Array] = None):
    self.device = device
    # The sharding and global shape stay private; `index` and `replica_id`
    # derive their values from them lazily.
    self._sharding = sharding
    self._global_shape = global_shape
    self.data = data

  def __repr__(self):
    try:
      fields = [f'device={self.device!r}', f'index={self.index}',
                f'replica_id={self.replica_id}', f'data={self.data}']
    except ValueError:
      # The sharding cannot map devices to indices / replica ids, so only the
      # always-available fields are shown.
      fields = [f'device={self.device!r}', f'data={self.data}']
    return f"Shard({', '.join(fields)})"

  @property
  def index(self) -> Index:
    missing = object()
    indices_for_device = getattr(self._sharding, 'device_indices', missing)
    if indices_for_device is missing:
      raise ValueError('Cannot calculate indices from sharding: '
                       f'{self._sharding}. Please create a device to index '
                       'mapping for your sharding.') from None
    shard_index = indices_for_device(self.device, self._global_shape)
    assert shard_index is not None
    return shard_index

  @property
  def replica_id(self) -> int:
    missing = object()
    replica_ids_for_shape = getattr(self._sharding, 'device_replica_id_map',
                                    missing)
    if replica_ids_for_shape is missing:
      raise ValueError('Cannot calculate replica ids from sharding: '
                       f'{self._sharding}. Please create a device to replica id '
                       'mapping for your sharding.') from None
    return replica_ids_for_shape(self._global_shape)[self.device]
|
2022-06-06 18:44:45 -07:00
|
|
|
|
|
|
|
|
|
|
2022-06-06 17:31:20 -07:00
|
|
|
|
class Array:
  """A JAX array whose data may be split across several devices.

  The array's abstract value (shape, dtype, weak type) comes from `aval`, and
  the layout of its data across devices is described by `sharding`. The data
  itself is held as a list of per-device buffers; when the sharding is an
  `XLACompatibleSharding` those buffers are kept in the order of its
  addressable device assignment. A host (numpy) copy is assembled lazily by
  `_value` and cached until `delete` is called.
  """
  # TODO(yashkatariya): Add __slots__ here.

  def __init__(self, aval: core.ShapedArray, sharding: Sharding,
               arrays: Union[Sequence[DeviceArray], Sequence[Array]], committed: bool):
    """Creates an Array.

    Args:
      aval: Abstract value of the global array; `shape`, `dtype` and the
        `weak_type` flag shown by `__repr__` are read from it.
      sharding: How the global array is laid out across devices.
      arrays: Per-device data, given either as raw device buffers or as
        `Array`s whose first buffer is used.
      committed: Whether the array is committed to its devices (see the FAQ
        link below).
    """
    self.aval = aval
    self._sharding = sharding
    # Extract DeviceArrays from arrays with `SingleDeviceSharding` to keep the
    # code handling `self._arrays` simpler.
    # TODO(yashkatariya): This will be slower as it will happen during
    # `__init__` on single controller environment. Make it lazy.
    self._arrays: List[DeviceArray] = [a if isinstance(a, DeviceArray) else a._arrays[0]
                                       for a in arrays]
    # See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
    # for what committed means.
    self._committed = committed
    # Cached host copy; filled in by `_value`, cleared by `delete`.
    self._npy_value = None

    if config.jax_enable_checks:
      assert all(db.dtype == self.dtype for db in self._arrays), (
          "Input arrays to `Array` must have matching dtypes, "
          f"got: {[db.dtype for db in self._arrays]}")

    # Rearrange arrays based on the device assignment.
    if isinstance(sharding, XLACompatibleSharding):
      device_to_buffer = {db.device().id: db for db in self._arrays}
      self._arrays = [device_to_buffer[device.id]
                      for device in self.sharding._addressable_device_assignment]

  @property
  def shape(self) -> Shape:
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    return prod(self.shape)

  @property
  def sharding(self):
    return self._sharding

  def __str__(self):
    return str(self._value)

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  def __bool__(self):
    return bool(self._value)

  def __nonzero__(self):
    return bool(self._value)

  def __float__(self):
    return self._value.__float__()

  def __int__(self):
    return self._value.__int__()

  def __complex__(self):
    return self._value.__complex__()

  def __hex__(self):
    assert self.ndim == 0, 'hex only works on scalar values'
    return hex(self._value)  # type: ignore

  def __oct__(self):
    assert self.ndim == 0, 'oct only works on scalar values'
    return oct(self._value)  # type: ignore

  def __index__(self):
    return op.index(self._value)

  def tobytes(self, order="C"):
    """Returns the raw bytes of the host copy, like `numpy.ndarray.tobytes`."""
    return self._value.tobytes(order)

  # Original spelling, kept so existing callers keep working; `tobytes`
  # matches the numpy method name.
  to_bytes = tobytes

  def tolist(self):
    return self._value.tolist()

  def __format__(self, format_spec):
    # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    if self.ndim == 0:
      return format(self._value[()], format_spec)
    else:
      return format(self._value, format_spec)

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error
    else:
      # chunk_iter is added to Array in lax_numpy.py similar to DA.
      return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())  # type: ignore

  def item(self):
    if dtypes.issubdtype(self.dtype, np.complexfloating):
      return complex(self)
    elif dtypes.issubdtype(self.dtype, np.floating):
      return float(self)
    elif dtypes.issubdtype(self.dtype, np.integer):
      return int(self)
    elif dtypes.issubdtype(self.dtype, np.bool_):
      return bool(self)
    else:
      raise TypeError(self.dtype)

  def __repr__(self):
    prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
    if self.aval is not None and self.aval.weak_type:
      dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
    else:
      dtype_str = f'dtype={self.dtype.name})'

    if self.is_fully_addressable():
      line_width = np.get_printoptions()["linewidth"]
      s = np.array2string(self._value, prefix=prefix, suffix=',',
                          separator=', ', max_line_width=line_width)
      # `array2string` only uses `prefix` for indenting continuation lines;
      # it does not include it in `s`. Measure the last line exactly as it
      # will be printed (prefix + data + trailing comma).
      head = f"{prefix}{s},"
      last_line_len = len(head) - (head.rfind('\n') + 1)
      sep = ' '
      if last_line_len + len(sep) + len(dtype_str) > line_width:
        # Put the dtype on its own line, aligned under the array data, which
        # is the layout numpy uses for wide arrays.
        sep = '\n' + ' ' * len(prefix)
      return f"{head}{sep}{dtype_str}"
    else:
      return f"{prefix}{self.shape}, {dtype_str}"

  def is_fully_addressable(self) -> bool:
    return self.sharding.is_fully_addressable()

  def __array__(self, dtype=None):
    return np.asarray(self._value, dtype=dtype)

  # TODO(yashkatariya): Remove this method when everyone is using devices().
  def device(self) -> Device:
    """Returns the only device of the sharding; raises if there are several."""
    device_set = self.sharding.device_set
    if len(device_set) == 1:
      single_device, = device_set
      return single_device
    raise ValueError('Length of devices is greater than 1. '
                     'Please use `.devices()`.')

  def devices(self) -> List[Device]:
    return list(self.sharding.device_set)

  @pxla.maybe_cached_property
  def addressable_shards(self) -> Sequence[Shard]:
    """One `Shard` per local buffer, each wrapping that buffer in an `Array`."""
    self._check_if_deleted()
    out = []
    for db in self._arrays:
      db = pxla._set_aval(db)
      device = db.device()
      # Wrap the device arrays in `Array` until C++ returns an Array instead
      # of a DA.
      array = Array(db.aval, SingleDeviceSharding(device), [db], committed=True)
      out.append(Shard(device, self.sharding, self.shape, array))
    return out

  def delete(self):
    """Deletes every device buffer and drops the cached host copy.

    Calling it again after deletion is a no-op.
    """
    if self._arrays is None:
      return
    for buf in self._arrays:
      buf.delete()
    self._arrays = None
    self._npy_value = None

  def _check_if_deleted(self):
    if self._arrays is None:
      raise RuntimeError("Array has been deleted.")

  def block_until_ready(self):
    """Waits on every device buffer, then returns `self`."""
    self._check_if_deleted()
    for db in self._arrays:
      db.block_until_ready()
    return self

  def _shards_with_unique_data(self) -> List[Shard]:
    """Returns the addressable shards whose data must be copied to host.

    If the sharding can report replica ids, only shards with replica id 0 are
    returned so each replicated piece of data is fetched once; otherwise every
    addressable shard is returned. With no addressable shards the result is
    empty.
    """
    shards = self.addressable_shards
    if not shards:
      return []
    try:
      shards[0].replica_id
    except ValueError:
      # The sharding has no replica-id mapping: fetch every shard.
      return list(shards)
    return [s for s in shards if s.replica_id == 0]

  def copy_to_host_async(self):
    """Starts copying the unique shard data to host, unless already cached."""
    self._check_if_deleted()
    if self._npy_value is None:
      for s in self._shards_with_unique_data():
        s.data._arrays[0].copy_to_host_async()  # pytype: disable=attribute-error

  @property
  def _value(self) -> np.ndarray:
    """The full array as a numpy value, assembled on host and cached.

    Raises:
      RuntimeError: if the array was deleted, or spans devices that are not
        addressable from this process.
    """
    self._check_if_deleted()
    if not self.is_fully_addressable():
      raise RuntimeError("Fetching value for `jax.Array` that spans "
                         "non-addressable devices is not possible. You can use "
                         "`jax.experimental.multihost_utils.process_allgather` "
                         "for this use case.")
    if self._npy_value is None:
      self.copy_to_host_async()
      npy_value = np.empty(self.shape, self.dtype)
      for s in self._shards_with_unique_data():
        npy_value[s.index] = s.data._arrays[0].to_py()  # type: ignore # [union-attr]
      self._npy_value = npy_value  # type: ignore
    # https://docs.python.org/3/library/typing.html#typing.cast
    return cast(np.ndarray, self._npy_value)
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
# `Array` instances are deliberately unhashable, matching what
# device_array.py does for DeviceArray.
Array.__hash__ = None  # type: ignore
|
2022-06-06 17:31:20 -07:00
|
|
|
|
|
|
|
|
|
def make_array_from_callback(shape: Shape, sharding: Sharding,
                             data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
  """Builds a committed `Array` by requesting each addressable shard's data.

  For every device in `sharding.addressable_devices`, `data_callback` is
  called with that device's index into the global array (as computed by
  `sharding.device_indices(device, shape)`), and its result is placed on the
  device with `device_put`.

  Args:
    shape: Shape of the global array.
    sharding: Layout of the global array across devices.
    data_callback: Maps a shard index to the data for that shard.

  Returns:
    A committed `Array` whose dtype is taken from the first per-device array
    and whose `weak_type` is False.
  """
  per_device = []
  for device in sharding.addressable_devices:
    shard_index = sharding.device_indices(device, shape)
    per_device.append(device_put(data_callback(shard_index), device))
  global_aval = core.ShapedArray(shape, per_device[0].dtype, weak_type=False)
  return Array(global_aval, sharding, per_device, committed=True)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
|
|
|
|
|
2022-08-17 12:25:14 -07:00
|
|
|
|
# Every aval-lookup registry reads an `Array`'s abstract value directly from
# its `aval` attribute.
for _aval_registry in (core.pytype_aval_mappings,
                       xla.pytype_aval_mappings,
                       api_util._shaped_abstractify_handlers):
  _aval_registry[Array] = op.attrgetter('aval')
del _aval_registry

# `pxla.identity` is the dtype-canonicalization handler for `Array`.
xla.canonicalize_dtype_handlers[Array] = pxla.identity

# ad_util handler tables: `lax_internal.add` as the adder and
# `lax_internal.zeros_like_array` as the zeros-like constructor.
ad_util.jaxval_adders[Array] = lax_internal.add
ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _array_mlir_constant_handler(val, canonicalize_types=True):
  """Lowers an `Array` constant by converting its host value.

  The array is first materialized on host (`val._value`), and that numpy value
  is handed to `mlir.ir_constants`.
  """
  host_value = val._value
  return mlir.ir_constants(host_value, canonicalize_types=canonicalize_types)


mlir.register_constant_handler(Array, _array_mlir_constant_handler)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
|
2022-06-24 10:04:31 -07:00
|
|
|
|
|
|
|
|
|
def _device_put_array(x, device: Optional[Device]):
  """`device_put` handler for `Array`.

  A single-device array has its one buffer copied directly to `device`. A
  sharded array is first gathered on host (via `x._value`) and then put on
  `device`.

  Args:
    x: The `Array` to place.
    device: Target device, passed through to the dispatch helpers.

  Returns:
    A tuple of device buffers, as produced by the dispatch helpers.

  Raises:
    RuntimeError: if `x` spans devices that are not addressable from this
      process.
  """
  # TODO(yashkatariya): Remove this restriction and the round trip via host
  # once lowering to XLA goes through `lower_mesh_computation`.
  # An explicit check rather than `assert`, so the restriction still holds
  # when Python runs with -O (which strips asserts).
  if not x.is_fully_addressable():
    raise RuntimeError("`device_put` of a `jax.Array` that spans "
                       "non-addressable devices is not supported.")
  if isinstance(x.sharding, SingleDeviceSharding):
    x = dispatch._copy_device_array_to_device(pxla._set_aval(x._arrays[0]), device)
    return (x,)
  else:
    # Round trip via host if x is sharded. SDA also does a round trip via host.
    return dispatch._device_put_array(x._value, device)


dispatch.device_put_handlers[Array] = _device_put_array
|
|
|
|
|
|
|
|
|
|
|
2022-08-10 20:11:06 -07:00
|
|
|
|
def _array_shard_arg(x, devices, indices, mode):
  """Returns the per-device buffers to pass for `x` as a computation argument.

  Outside pmap mode the array's own buffers are returned unchanged. In pmap
  mode a single-device array is delegated to `pxla._shard_device_array`;
  otherwise each buffer is paired positionally with `devices` and copied to
  its target device if it does not already live there.
  """
  # TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
  # code paths.
  if mode != pxla.InputsHandlerMode.pmap:
    return x._arrays
  # sharding mismatch between `Array` and pmap sharding is checked in api.py's
  # `_check_in_pmap_sharding_with_arrays` function.
  if isinstance(x.sharding, SingleDeviceSharding):
    return pxla._shard_device_array(x, devices, indices, mode)
  placed = []
  for buf, target in safe_zip(x._arrays, devices):
    placed.append(buf if buf.device() == target else buf.copy_to_device(target))
  return placed


pxla.shard_arg_handlers[Array] = _array_shard_arg
|
|
|
|
|
|
|
|
|
|
|
2022-08-10 20:11:06 -07:00
|
|
|
|
def _array_global_result_handler(global_aval, out_sharding):
  """Returns a callable that wraps output buffers into a committed `Array`.

  Args:
    global_aval: Abstract value of the full (global) output.
    out_sharding: Sharding of that output.
  """
  def wrap_buffers(bufs):
    return Array(global_aval, out_sharding, bufs, committed=True)
  return wrap_buffers


for _aval_type in (core.ShapedArray, core.ConcreteArray):
  pxla.global_result_handlers[(_aval_type, pxla.OutputType.Array)] = (
      _array_global_result_handler)
del _aval_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _array_local_result_handler(aval, sharding, indices):
  """Returns a callable that wraps output buffers into a committed `Array`.

  `indices` is part of the handler signature but is not used here.
  """
  def wrap_buffers(bufs):
    return Array(aval, sharding, bufs, committed=True)
  return wrap_buffers


for _local_aval_type in (core.ShapedArray, core.ConcreteArray):
  pxla.local_result_handlers[(_local_aval_type, pxla.OutputType.Array)] = (
      _array_local_result_handler)
del _local_aval_type
|