diff --git a/jax/_src/api.py b/jax/_src/api.py index cb5d8d58f..fb723d027 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1820,7 +1820,7 @@ class _PmapFastpathData(NamedTuple): input_devices: Sequence[xc.Device] input_indices: Sequence[pxla.Index] input_array_shardings: Sequence[Any] - # Data needed to build the ShardedDeviceArray from C++. + # Data needed to build the Array from C++. out_sharding_specs: Sequence[pxla.ShardingSpec] out_indices: Sequence[pxla.Index] out_avals: Sequence[Any] @@ -1891,8 +1891,7 @@ def _cpp_pmap( # TODO(sharadmv): Enable effects in replicated computation not execute_replicated.has_unordered_effects and not execute_replicated.has_host_callbacks and - # No tracers in the outputs. Checking for ShardedDeviceArray should be - # sufficient, but we use the more general `DeviceArray`. + # No tracers in the outputs. all( isinstance(x, device_array.DeviceArray) or isinstance(x, xc.ArrayImpl) for x in out_flat)) @@ -2528,7 +2527,7 @@ def device_put( def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # noqa: F811 - """Transfer array shards to specified devices and form ShardedDeviceArray(s). + """Transfer array shards to specified devices and form Array(s). Args: shards: A sequence of arrays, scalars, or (nested) standard Python @@ -2540,7 +2539,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # This function is always asynchronous, i.e. returns immediately. Returns: - A ShardedDeviceArray or (nested) Python container thereof representing the + An Array or (nested) Python container thereof representing the elements of ``shards`` stacked together, with each shard backed by physical device memory specified by the corresponding entry in ``devices``. @@ -2603,7 +2602,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 - """Transfer array(s) to each specified device and form ShardedDeviceArray(s). + """Transfer array(s) to each specified device and form Array(s). Args: x: an array, scalar, or (nested) standard Python container thereof @@ -2614,7 +2613,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 This function is always asynchronous, i.e. returns immediately. Returns: - A ShardedDeviceArray or (nested) Python container thereof representing the + An Array or (nested) Python container thereof representing the value of ``x`` broadcasted along a new leading axis of size ``len(devices)``, with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in ``devices``.
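# --- Illustrative sketch (not part of the patch) ---
# The api.py docstring updates above say that device_put_sharded and
# device_put_replicated now form jax.Array values rather than
# ShardedDeviceArray. A minimal, hedged example of that user-facing contract,
# using only public JAX APIs and assuming a JAX version where jax.Array is the
# unified array type:
import numpy as np
import jax

devices = jax.local_devices()

# One shard per device; the result is a single jax.Array whose leading axis is
# laid out across `devices`.
shards = [np.arange(4.0) + i for i in range(len(devices))]
stacked = jax.device_put_sharded(shards, devices)
assert isinstance(stacked, jax.Array)

# Replicate one value onto every device along a new leading axis of size
# len(devices).
replicated = jax.device_put_replicated(np.arange(4.0), devices)
assert isinstance(replicated, jax.Array)

# pmap outputs are jax.Array as well; the fast path in _cpp_pmap above now
# just checks for DeviceArray/ArrayImpl outputs instead of special-casing
# ShardedDeviceArray.
doubled = jax.pmap(lambda x: 2 * x)(stacked)
assert isinstance(doubled, jax.Array)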
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c6f3498f7..983245752 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -38,11 +38,11 @@ from functools import partial, lru_cache, cached_property import itertools as it import logging import math -import operator as op import sys import threading from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, - Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast) + Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast, + TYPE_CHECKING) import numpy as np @@ -51,9 +51,7 @@ from jax.errors import JAXTypeError from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map -from jax._src import abstract_arrays from jax._src import api_util -from jax._src import basearray from jax._src import core from jax._src import device_array from jax._src import dispatch @@ -383,18 +381,8 @@ def shard_arg(arg, devices, arg_indices, sharding): devices: The list of devices to shard over. arg_indices: A list of `len(devices)` indices to use to shard the argument. """ - if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices: - # The shard_arg_handlers allow an extensible set of types to be sharded, but - # inline handling for ShardedDeviceArray as a special case for performance - # NOTE: we compare indices instead of sharding_spec because - # pmap_benchmark.pmapshard_args_benchmark indicates this is faster. - return [ - buf if buf.device() == d else buf.copy_to_device(d) - for d, buf in zip(devices, arg.device_buffers) - ] - else: - arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding) + arg = xla.canonicalize_dtype(arg) + return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding) @profiler.annotate_function @@ -626,9 +614,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} ### lazy device-memory persistence and result handling -# TODO(jblespiau): Consider removing this option. -_USE_CPP_SDA = True - def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None): if sharded_dim is not None: @@ -685,92 +670,15 @@ def make_sharded_device_array( return jax.make_array_from_single_device_arrays( aval.shape, sharding, device_buffers) # type: ignore -if _USE_CPP_SDA: - ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore - # We want the C++ SDA to extend the DeviceArrayBase. We want this both to - # benefit from its methods, and to have isinstance(x, DeviceArray) return true - ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore - ShardedDeviceArrayBase.__bases__) - _SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore + + +if TYPE_CHECKING: + ShardedDeviceArray = Any else: - _SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore -basearray.Array.register(_SDA_BASE_CLASS) - - -class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore - """A ShardedDeviceArray is an ndarray sharded across devices. - - The purpose of a ShardedDeviceArray is to reduce the number of transfers when - executing replicated computations, by allowing results to persist on the - devices that produced them. That way dispatching a similarly replicated - computation that consumes the same sharded memory layout does not incur any - transfers. 
- - A ShardedDeviceArray represents one logical ndarray value, and simulates the - behavior of an ndarray so that it can be treated by user code as an ndarray; - that is, it is only an optimization to reduce transfers. - - Attributes: - aval: A ShapedArray indicating the shape and dtype of this array. - sharding_spec: describes how this array is sharded across `device_buffers`. - device_buffers: the buffers containing the data for this array. Each buffer - is the same shape and on a different device. Buffers are in row-major - order, with replication treated as an extra innermost dimension. - indices: the result of spec_to_indices(sharding_spec). Can optionally be - precomputed for efficiency. A list the same length as - `device_buffers`. Each index indicates what portion of the full array is - stored in the corresponding device buffer, i.e. `array[indices[i]] == - np.asarray(device_buffers[i])`. - """ - __slots__ = [ - "aval", "device_buffers", "sharding_spec", "indices", - "_one_replica_buffer_indices", "_npy_value" - ] - - def __init__(self, - aval: ShapedArray, - sharding_spec: ShardingSpec, - device_buffers: List[xb.xla_client.Buffer], - indices: Optional[Tuple[Index, ...]] = None): - super().__init__() - - # TODO(skye): assert invariants. Keep performance in mind though. - if indices is None: - indices = spec_to_indices(aval.shape, sharding_spec) - - self.aval = aval - self.device_buffers = device_buffers - self.sharding_spec = sharding_spec - self.indices = indices - self._npy_value = None - self._one_replica_buffer_indices = None - if config.jax_enable_checks: - assert type(aval) is ShapedArray - - @property - def shape(self): - return self.aval.shape - - @property - def dtype(self): - return self.aval.dtype - - @property - def size(self): - return math.prod(self.aval.shape) - - @property - def ndim(self): - return len(self.aval.shape) - - def delete(self): - if self.device_buffers is None: - return - for buf in self.device_buffers: - buf.delete() - self.device_buffers = None - self._npy_value = None - + class ShardedDeviceArray(object): + def __init__(self): + raise RuntimeError("ShardedDeviceArray is a backward compatibility shim " + "and cannot be instantiated.") def _one_replica_buffer_indices(indices: Tuple[Index, ...]): """Returns a set of buffer-indices containing one complete copy of the array.""" @@ -783,120 +691,6 @@ def _one_replica_buffer_indices(indices: Tuple[Index, ...]): seen_index_hashes.add(hashed_index) return one_replica_indices - -def _sda_one_replica_buffer_indices(self): - """Indices of buffers containing one complete copy of the array data.""" - if self._one_replica_buffer_indices is None: - self._one_replica_buffer_indices = _one_replica_buffer_indices(self.indices) - return self._one_replica_buffer_indices - - -def _sda_copy_to_host_async(self): - for buffer_index in self.one_replica_buffer_indices: - self.device_buffers[buffer_index].copy_to_host_async() - - -def _sda_check_if_deleted(self): - if self.device_buffers is None: - raise ValueError("ShardedDeviceArray has been deleted.") - - -def _sda_block_until_ready(self): - self._check_if_deleted() - for buf in self.device_buffers: - buf.block_until_ready() - return self - - -def _sda_value(self): - if self._npy_value is None: - self.copy_to_host_async() - npy_value = np.empty(self.aval.shape, self.aval.dtype) - for i in self.one_replica_buffer_indices: - npy_value[self.indices[i]] = np.asarray(self.device_buffers[i]) - self._npy_value = npy_value - return self._npy_value - - -def _sda__getitem__(self, 
idx): - self._check_if_deleted() - if not isinstance(idx, tuple): - cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1) - else: - cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx)) - if self._npy_value is None: - try: - buf_idx = self.indices.index(cidx) - except ValueError: - buf_idx = None - if buf_idx is not None: - buf = self.device_buffers[buf_idx] - aval = ShapedArray(buf.shape, self.aval.dtype) - return device_array.make_device_array(aval, None, buf) - return super(self.__class__, self).__getitem__(idx) - - -def _sda__iter__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d array") # same as numpy error - else: - return (self[i] for i in range(self.shape[0])) - -def _sda__reversed__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d array") # same as numpy error - else: - return (self[i] for i in range(self.shape[0] - 1, -1, -1)) - - -def _sda_sharding(self): - has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding) - if has_unstacked: - devices = np.array([d.device() for d in self.device_buffers]) - return sharding_impls.PmapSharding(devices, self.sharding_spec) - raise NotImplementedError( - 'SDAs that are the output of pjit/xmap do not have the sharding attribute ' - 'implemented. If you are trying to pass the SDA to pjit/xmap, please ' - 'use multihost_utils.host_local_array_to_global_array(...) to convert ' - 'SDAs to global `jax.Array` and then pass them to pjit/xmap with ' - '`jax_array` enabled.') - -# TODO(yashkatariya): Remove this when SDA is deleted. The local import of Array -# will also go away. -def _sda_addressable_shards(self): - from jax._src import array - out = [] - for db in self.device_buffers: - db = dispatch._set_aval(db) - out.append(array.Shard(db.device(), self.sharding, self.shape, db)) - return out - - -for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]: - setattr(sda, "one_replica_buffer_indices", - property(_sda_one_replica_buffer_indices)) - setattr(sda, "copy_to_host_async", _sda_copy_to_host_async) - setattr(sda, "_check_if_deleted", _sda_check_if_deleted) - setattr(sda, "block_until_ready", _sda_block_until_ready) - setattr(sda, "_value", property(_sda_value)) - setattr(sda, "__getitem__", _sda__getitem__) - setattr(sda, "__iter__", _sda__iter__) - setattr(sda, "__reversed__", _sda__reversed__) - setattr(sda, "sharding", property(_sda_sharding)) - setattr(sda, "addressable_shards", property(_sda_addressable_shards)) - -del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async, - _sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__, - _sda_sharding, _sda_addressable_shards) - - -ShardedDeviceArray: Type[object] -if _USE_CPP_SDA: - ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase -else: - ShardedDeviceArray = _ShardedDeviceArray - - def _hashable_index(idx): return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) @@ -936,24 +730,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): return batched_device_put(x.aval, sharding, bufs, devices) - -def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True): - return mlir.ir_constants(np.asarray(val), - canonicalize_types=canonicalize_types) - -def _register_handlers_for_sharded_device_array(sda): - shard_arg_handlers[sda] = shard_sharded_device_array_slow_path - mlir.register_constant_handler(sda, - _sharded_device_array_mlir_constant_handler) - - core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval - 
xla.pytype_aval_mappings[sda] = op.attrgetter("aval") - xla.canonicalize_dtype_handlers[sda] = identity - api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval") - -_register_handlers_for_sharded_device_array(_ShardedDeviceArray) -_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray) - ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py @@ -1049,7 +825,7 @@ def _emap_impl(fun: lu.WrappedFun, *args, for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals): with jax.disable_jit(False): donate_argnums_ = donate_argnums - if isinstance(outval, (ShardedDeviceArray, array.ArrayImpl)): + if isinstance(outval, array.ArrayImpl): # We don't want to donate if it's already sharded. donate_argnums_ = () out = jax.pmap( diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f595ccde2..67a9c14cc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -62,7 +62,6 @@ from jax._src.lax.utils import ( standard_primitive, standard_translate, ) -from jax._src.lib import pmap_lib from jax._src.lib import pytree from jax._src import xla_bridge from jax._src.lib import xla_client, xla_extension_version @@ -1529,14 +1528,10 @@ def zeros_like_array(x: ArrayLike) -> Array: for t in itertools.chain( dtypes.python_scalar_dtypes.keys(), array_types, - device_array.device_array_types, [array.ArrayImpl], - [pxla.ShardedDeviceArray, pxla._ShardedDeviceArray, - pmap_lib.ShardedDeviceArray]): + device_array.device_array_types, [array.ArrayImpl]): ad_util.jaxval_adders[t] = add ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array -ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array -ad_util.jaxval_zeros_likers[pmap_lib.ShardedDeviceArray] = zeros_like_array ad_util.jaxval_zeros_likers[array.ArrayImpl] = zeros_like_array diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0355c1fd4..a698a9342 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5692,6 +5692,4 @@ def _set_device_array_attributes(device_array): for t in device_array.device_array_types: _set_device_array_attributes(t) -_set_device_array_attributes(pxla._ShardedDeviceArray) -_set_device_array_attributes(pmap_lib.ShardedDeviceArray) _set_device_array_attributes(ArrayImpl) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 998d81e53..437da3e1a 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -45,7 +45,6 @@ from jax._src.interpreters.pxla import ( ShardInfo as ShardInfo, ShardedAxis as ShardedAxis, ShardedDeviceArray as ShardedDeviceArray, - ShardedDeviceArrayBase as ShardedDeviceArrayBase, ShardingSpec as ShardingSpec, TileManual as TileManual, TileVectorize as TileVectorize, @@ -54,7 +53,6 @@ from jax._src.interpreters.pxla import ( UnloadedPmapExecutable as UnloadedPmapExecutable, Unstacked as Unstacked, WeakRefList as WeakRefList, - _ShardedDeviceArray as _ShardedDeviceArray, _UNSPECIFIED as _UNSPECIFIED, _create_pmap_sharding_spec as _create_pmap_sharding_spec, _get_and_check_device_assignment as _get_and_check_device_assignment, diff --git a/tests/api_test.py b/tests/api_test.py index 02bebd1bc..d7bc434fb 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -54,7 +54,6 @@ from jax.errors import UnexpectedTracerError from jax.interpreters import ad from jax._src.interpreters import mlir from jax.interpreters import xla -from jax.interpreters import pxla from 
jax.interpreters import batching from jax.interpreters import partial_eval as pe from jax.sharding import PartitionSpec as P @@ -1489,13 +1488,9 @@ class APITest(jtu.JaxTestCase): def test_is_subclass(self): self.assertTrue(issubclass(device_array.DeviceArray, jax.Array)) self.assertTrue(issubclass(device_array.Buffer, jax.Array)) - self.assertTrue(issubclass(pxla.ShardedDeviceArray, jax.Array)) - self.assertTrue(issubclass(pxla._ShardedDeviceArray, jax.Array)) self.assertFalse(issubclass(np.ndarray, jax.Array)) self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray)) self.assertFalse(issubclass(device_array.Buffer, np.ndarray)) - self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray)) - self.assertFalse(issubclass(pxla._ShardedDeviceArray, np.ndarray)) def test_is_instance(self): def f(x): diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 0f99df21c..c61ed6735 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -108,8 +108,6 @@ class JaxJitTest(jtu.JaxTestCase): sda = pmaped_f(np.asarray([[1]])) output_buffer = device_put_function(sda, device=device) - self.assertNotIsInstance(output_buffer, - jax.interpreters.pxla.ShardedDeviceArray) self.assertEqual(output_buffer.dtype, sda.dtype) self.assertEqual(output_buffer.aval, sda.aval) np.testing.assert_array_equal(output_buffer, np.asarray(sda)) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 140dad5d8..7814a0070 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -759,14 +759,14 @@ class PythonPmapTest(jtu.JaxTestCase): expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x) self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3) - def testShardedDeviceArrays(self): + def testArrays(self): f = lambda x: 2 * x f = self.pmap(f, axis_name='i') shape = (jax.device_count(), 4) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - # test that we can pass in and out ShardedDeviceArrays + # test that we can pass in and out Arrays y = f(x) self.assertIsInstance(y, jax.Array) self.assertIsInstance(y, array.ArrayImpl) @@ -782,7 +782,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertIsInstance(y, array.ArrayImpl) self.assertAllClose(y, 2 * x, check_dtypes=False) - # test that we can pass a ShardedDeviceArray to a regular jit computation + # test that we can pass an Array to a regular jit computation z = y + y self.assertAllClose(z, 2 * 2 * x, check_dtypes=False) @@ -809,7 +809,7 @@ class PythonPmapTest(jtu.JaxTestCase): for in_shape, out_shape in [ [(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)] ]) - def testShardedDeviceArrayReshape(self, in_shape, out_shape): + def testArrayReshape(self, in_shape, out_shape): if jax.device_count() < max(in_shape[:1] + out_shape[:1]): raise SkipTest("not enough devices") @@ -1551,7 +1551,7 @@ class PythonPmapTest(jtu.JaxTestCase): 1, 2).reshape(shape) self.assertAllClose(fn(x, w), expected, check_dtypes=False) - def testShardedDeviceArrayBlockUntilReady(self): + def testArrayBlockUntilReady(self): x = np.arange(jax.device_count()) x = self.pmap(lambda x: x)(x) x.block_until_ready() # doesn't crash @@ -1602,7 +1602,7 @@ class PythonPmapTest(jtu.JaxTestCase): multi_step_pmap(jnp.zeros((device_count,)), count=1) - def testShardedDeviceArrayGetItem(self): + def testArrayGetItem(self): f = lambda x: 2 * x f = self.pmap(f, axis_name='i') @@ -2612,7 +2612,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): @jtu.pytest_mark_if_available('multiaccelerator') -class ShardedDeviceArrayTest(jtu.JaxTestCase): +class ArrayTest(jtu.JaxTestCase): 
def testThreadsafeIndexing(self): # NOTE(skye): I picked these values to be big enough to cause interesting @@ -2868,7 +2868,7 @@ class ShardArgsTest(jtu.JaxTestCase): def device_array(x): return jax.device_put(x) - # TODO(skye): add coverage for ShardedDeviceArrays + # TODO(skye): add coverage for Arrays @parameterized.named_parameters( {"testcase_name":