[JAX] Delete ShardedDeviceArray.

Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.

PiperOrigin-RevId: 518074394
Peter Hawkins 2023-03-20 14:17:25 -07:00 committed by jax authors
parent 143dfcd74b
commit 926e42e025
8 changed files with 28 additions and 269 deletions
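
The compatibility shim introduced in pxla.py (see the hunk below) follows a simple pattern: the name resolves to Any for static type checkers, while at runtime it is a class whose constructor always raises. A minimal, self-contained sketch of that pattern:

    from typing import Any, TYPE_CHECKING

    if TYPE_CHECKING:
      ShardedDeviceArray = Any
    else:
      class ShardedDeviceArray:
        def __init__(self):
          raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
                             "and cannot be instantiated.")

Code that only mentions ShardedDeviceArray in type annotations keeps type-checking as before, while any attempt to construct one fails loudly.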

View File

@ -1820,7 +1820,7 @@ class _PmapFastpathData(NamedTuple):
input_devices: Sequence[xc.Device]
input_indices: Sequence[pxla.Index]
input_array_shardings: Sequence[Any]
# Data needed to build the ShardedDeviceArray from C++.
# Data needed to build the Array from C++.
out_sharding_specs: Sequence[pxla.ShardingSpec]
out_indices: Sequence[pxla.Index]
out_avals: Sequence[Any]
@ -1891,8 +1891,7 @@ def _cpp_pmap(
# TODO(sharadmv): Enable effects in replicated computation
not execute_replicated.has_unordered_effects
and not execute_replicated.has_host_callbacks and
# No tracers in the outputs. Checking for ShardedDeviceArray should be
# sufficient, but we use the more general `DeviceArray`.
# No tracers in the outputs.
all(
isinstance(x, device_array.DeviceArray) or isinstance(x, xc.ArrayImpl)
for x in out_flat))
@ -2528,7 +2527,7 @@ def device_put(
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array shards to specified devices and form ShardedDeviceArray(s).
"""Transfer array shards to specified devices and form Array(s).
Args:
shards: A sequence of arrays, scalars, or (nested) standard Python
@ -2540,7 +2539,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
This function is always asynchronous, i.e. returns immediately.
Returns:
A ShardedDeviceArray or (nested) Python container thereof representing the
An Array or (nested) Python container thereof representing the
elements of ``shards`` stacked together, with each shard backed by physical
device memory specified by the corresponding entry in ``devices``.
@ -2603,7 +2602,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array(s) to each specified device and form ShardedDeviceArray(s).
"""Transfer array(s) to each specified device and form Array(s).
Args:
x: an array, scalar, or (nested) standard Python container thereof
@ -2614,7 +2613,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
This function is always asynchronous, i.e. returns immediately.
Returns:
A ShardedDeviceArray or (nested) Python container thereof representing the
An Array or (nested) Python container thereof representing the
value of ``x`` broadcasted along a new leading axis of size
``len(devices)``, with each slice along that new leading axis backed by
memory on the device specified by the corresponding entry in ``devices``.
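
A minimal usage sketch of the two helpers whose docstrings change above; both now produce jax.Array values (the devices, shapes, and values here are illustrative only):

    import jax
    import numpy as np

    devices = jax.local_devices()
    # One shard per device; device_put_sharded stacks them along a new leading axis.
    shards = [np.full((3,), i, dtype=np.float32) for i in range(len(devices))]
    stacked = jax.device_put_sharded(shards, devices)   # shape (len(devices), 3)

    # device_put_replicated broadcasts one value along a new leading axis.
    replicated = jax.device_put_replicated(np.zeros((3,), np.float32), devices)

    assert isinstance(stacked, jax.Array)
    assert isinstance(replicated, jax.Array)
    assert stacked.shape == replicated.shape == (len(devices), 3)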

View File

@ -38,11 +38,11 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import operator as op
import sys
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
import numpy as np
@ -51,9 +51,7 @@ from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
@ -383,18 +381,8 @@ def shard_arg(arg, devices, arg_indices, sharding):
devices: The list of devices to shard over.
arg_indices: A list of `len(devices)` indices to use to shard the argument.
"""
if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
# The shard_arg_handlers allow an extensible set of types to be sharded, but
# inline handling for ShardedDeviceArray as a special case for performance
# NOTE: we compare indices instead of sharding_spec because
# pmap_benchmark.pmapshard_args_benchmark indicates this is faster.
return [
buf if buf.device() == d else buf.copy_to_device(d)
for d, buf in zip(devices, arg.device_buffers)
]
else:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
@profiler.annotate_function
@ -626,9 +614,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
# TODO(jblespiau): Consider removing this option.
_USE_CPP_SDA = True
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
if sharded_dim is not None:
@ -685,92 +670,15 @@ def make_sharded_device_array(
return jax.make_array_from_single_device_arrays(
aval.shape, sharding, device_buffers) # type: ignore
if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
basearray.Array.register(_SDA_BASE_CLASS)
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
"""A ShardedDeviceArray is an ndarray sharded across devices.
The purpose of a ShardedDeviceArray is to reduce the number of transfers when
executing replicated computations, by allowing results to persist on the
devices that produced them. That way dispatching a similarly replicated
computation that consumes the same sharded memory layout does not incur any
transfers.
A ShardedDeviceArray represents one logical ndarray value, and simulates the
behavior of an ndarray so that it can be treated by user code as an ndarray;
that is, it is only an optimization to reduce transfers.
Attributes:
aval: A ShapedArray indicating the shape and dtype of this array.
sharding_spec: describes how this array is sharded across `device_buffers`.
device_buffers: the buffers containing the data for this array. Each buffer
is the same shape and on a different device. Buffers are in row-major
order, with replication treated as an extra innermost dimension.
indices: the result of spec_to_indices(sharding_spec). Can optionally be
precomputed for efficiency. A list the same length as
`device_buffers`. Each index indicates what portion of the full array is
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
np.asarray(device_buffers[i])`.
"""
__slots__ = [
"aval", "device_buffers", "sharding_spec", "indices",
"_one_replica_buffer_indices", "_npy_value"
]
def __init__(self,
aval: ShapedArray,
sharding_spec: ShardingSpec,
device_buffers: List[xb.xla_client.Buffer],
indices: Optional[Tuple[Index, ...]] = None):
super().__init__()
# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
self.aval = aval
self.device_buffers = device_buffers
self.sharding_spec = sharding_spec
self.indices = indices
self._npy_value = None
self._one_replica_buffer_indices = None
if config.jax_enable_checks:
assert type(aval) is ShapedArray
@property
def shape(self):
return self.aval.shape
@property
def dtype(self):
return self.aval.dtype
@property
def size(self):
return math.prod(self.aval.shape)
@property
def ndim(self):
return len(self.aval.shape)
def delete(self):
if self.device_buffers is None:
return
for buf in self.device_buffers:
buf.delete()
self.device_buffers = None
self._npy_value = None
class ShardedDeviceArray(object):
def __init__(self):
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
"and cannot be instantiated.")
def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
"""Returns a set of buffer-indices containing one complete copy of the array."""
@ -783,120 +691,6 @@ def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
seen_index_hashes.add(hashed_index)
return one_replica_indices
def _sda_one_replica_buffer_indices(self):
"""Indices of buffers containing one complete copy of the array data."""
if self._one_replica_buffer_indices is None:
self._one_replica_buffer_indices = _one_replica_buffer_indices(self.indices)
return self._one_replica_buffer_indices
def _sda_copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()
def _sda_check_if_deleted(self):
if self.device_buffers is None:
raise ValueError("ShardedDeviceArray has been deleted.")
def _sda_block_until_ready(self):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_until_ready()
return self
def _sda_value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
self._npy_value = npy_value
return self._npy_value
def _sda__getitem__(self, idx):
self._check_if_deleted()
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.shape, self.aval.dtype)
return device_array.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)
def _sda__iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0]))
def _sda__reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))
def _sda_sharding(self):
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
if has_unstacked:
devices = np.array([d.device() for d in self.device_buffers])
return sharding_impls.PmapSharding(devices, self.sharding_spec)
raise NotImplementedError(
'SDAs that are the output of pjit/xmap do not have the sharding attribute '
'implemented. If you are trying to pass the SDA to pjit/xmap, please '
'use multihost_utils.host_local_array_to_global_array(...) to convert '
'SDAs to global `jax.Array` and then pass them to pjit/xmap with '
'`jax_array` enabled.')
# TODO(yashkatariya): Remove this when SDA is deleted. The local import of Array
# will also go away.
def _sda_addressable_shards(self):
from jax._src import array
out = []
for db in self.device_buffers:
db = dispatch._set_aval(db)
out.append(array.Shard(db.device(), self.sharding, self.shape, db))
return out
for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
setattr(sda, "one_replica_buffer_indices",
property(_sda_one_replica_buffer_indices))
setattr(sda, "copy_to_host_async", _sda_copy_to_host_async)
setattr(sda, "_check_if_deleted", _sda_check_if_deleted)
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)
setattr(sda, "__iter__", _sda__iter__)
setattr(sda, "__reversed__", _sda__reversed__)
setattr(sda, "sharding", property(_sda_sharding))
setattr(sda, "addressable_shards", property(_sda_addressable_shards))
del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__,
_sda_sharding, _sda_addressable_shards)
ShardedDeviceArray: Type[object]
if _USE_CPP_SDA:
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
else:
ShardedDeviceArray = _ShardedDeviceArray
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
@ -936,24 +730,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
return batched_device_put(x.aval, sharding, bufs, devices)
def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(np.asarray(val),
canonicalize_types=canonicalize_types)
def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = shard_sharded_device_array_slow_path
mlir.register_constant_handler(sda,
_sharded_device_array_mlir_constant_handler)
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")
_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
@ -1049,7 +825,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
donate_argnums_ = donate_argnums
if isinstance(outval, (ShardedDeviceArray, array.ArrayImpl)):
if isinstance(outval, array.ArrayImpl):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(
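
make_sharded_device_array (kept above as a thin wrapper) now assembles its result with jax.make_array_from_single_device_arrays instead of constructing a ShardedDeviceArray. A rough sketch of that construction path, assuming a one-dimensional mesh over the local devices and illustrative shapes:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = np.array(jax.local_devices())
    mesh = Mesh(devices, ('x',))
    sharding = NamedSharding(mesh, P('x'))   # shard the leading axis across 'x'

    global_shape = (2 * len(devices), 4)
    # One single-device array per device, each holding the rows that device owns.
    per_device = [
        jax.device_put(np.full((2, 4), i, dtype=np.float32), d)
        for i, d in enumerate(devices)
    ]
    out = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device)
    assert isinstance(out, jax.Array) and out.shape == global_shape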

View File

@ -62,7 +62,6 @@ from jax._src.lax.utils import (
standard_primitive,
standard_translate,
)
from jax._src.lib import pmap_lib
from jax._src.lib import pytree
from jax._src import xla_bridge
from jax._src.lib import xla_client, xla_extension_version
@ -1529,14 +1528,10 @@ def zeros_like_array(x: ArrayLike) -> Array:
for t in itertools.chain(
dtypes.python_scalar_dtypes.keys(), array_types,
device_array.device_array_types, [array.ArrayImpl],
[pxla.ShardedDeviceArray, pxla._ShardedDeviceArray,
pmap_lib.ShardedDeviceArray]):
device_array.device_array_types, [array.ArrayImpl]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[pmap_lib.ShardedDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[array.ArrayImpl] = zeros_like_array

View File

@ -5692,6 +5692,4 @@ def _set_device_array_attributes(device_array):
for t in device_array.device_array_types:
_set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pmap_lib.ShardedDeviceArray)
_set_device_array_attributes(ArrayImpl)

View File

@ -45,7 +45,6 @@ from jax._src.interpreters.pxla import (
ShardInfo as ShardInfo,
ShardedAxis as ShardedAxis,
ShardedDeviceArray as ShardedDeviceArray,
ShardedDeviceArrayBase as ShardedDeviceArrayBase,
ShardingSpec as ShardingSpec,
TileManual as TileManual,
TileVectorize as TileVectorize,
@ -54,7 +53,6 @@ from jax._src.interpreters.pxla import (
UnloadedPmapExecutable as UnloadedPmapExecutable,
Unstacked as Unstacked,
WeakRefList as WeakRefList,
_ShardedDeviceArray as _ShardedDeviceArray,
_UNSPECIFIED as _UNSPECIFIED,
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
_get_and_check_device_assignment as _get_and_check_device_assignment,

View File

@ -54,7 +54,6 @@ from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P
@ -1489,13 +1488,9 @@ class APITest(jtu.JaxTestCase):
def test_is_subclass(self):
self.assertTrue(issubclass(device_array.DeviceArray, jax.Array))
self.assertTrue(issubclass(device_array.Buffer, jax.Array))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jax.Array))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jax.Array))
self.assertFalse(issubclass(np.ndarray, jax.Array))
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
self.assertFalse(issubclass(pxla._ShardedDeviceArray, np.ndarray))
def test_is_instance(self):
def f(x):

View File

@ -108,8 +108,6 @@ class JaxJitTest(jtu.JaxTestCase):
sda = pmaped_f(np.asarray([[1]]))
output_buffer = device_put_function(sda, device=device)
self.assertNotIsInstance(output_buffer,
jax.interpreters.pxla.ShardedDeviceArray)
self.assertEqual(output_buffer.dtype, sda.dtype)
self.assertEqual(output_buffer.aval, sda.aval)
np.testing.assert_array_equal(output_buffer, np.asarray(sda))

View File

@ -759,14 +759,14 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)
def testShardedDeviceArrays(self):
def testArrays(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# test that we can pass in and out ShardedDeviceArrays
# test that we can pass in and out Arrays
y = f(x)
self.assertIsInstance(y, jax.Array)
self.assertIsInstance(y, array.ArrayImpl)
@ -782,7 +782,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIsInstance(y, array.ArrayImpl)
self.assertAllClose(y, 2 * x, check_dtypes=False)
# test that we can pass a ShardedDeviceArray to a regular jit computation
# test that we can pass an Array to a regular jit computation
z = y + y
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)
@ -809,7 +809,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for in_shape, out_shape in [
[(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
def testArrayReshape(self, in_shape, out_shape):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")
@ -1551,7 +1551,7 @@ class PythonPmapTest(jtu.JaxTestCase):
1, 2).reshape(shape)
self.assertAllClose(fn(x, w), expected, check_dtypes=False)
def testShardedDeviceArrayBlockUntilReady(self):
def testArrayBlockUntilReady(self):
x = np.arange(jax.device_count())
x = self.pmap(lambda x: x)(x)
x.block_until_ready() # doesn't crash
@ -1602,7 +1602,7 @@ class PythonPmapTest(jtu.JaxTestCase):
multi_step_pmap(jnp.zeros((device_count,)), count=1)
def testShardedDeviceArrayGetItem(self):
def testArrayGetItem(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
@ -2612,7 +2612,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@jtu.pytest_mark_if_available('multiaccelerator')
class ShardedDeviceArrayTest(jtu.JaxTestCase):
class ArrayTest(jtu.JaxTestCase):
def testThreadsafeIndexing(self):
# NOTE(skye): I picked these values to be big enough to cause interesting
@ -2868,7 +2868,7 @@ class ShardArgsTest(jtu.JaxTestCase):
def device_array(x):
return jax.device_put(x)
# TODO(skye): add coverage for ShardedDeviceArrays
# TODO(skye): add coverage for Arrays
@parameterized.named_parameters(
{"testcase_name":