mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

[JAX] Delete ShardedDeviceArray.

Replace it with a temporary shim that appears as Any to type checkers and is an uninstantiable class at runtime.

PiperOrigin-RevId: 518074394
This commit is contained in:
parent 143dfcd74b
commit 926e42e025
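The shim the commit message describes appears in the pxla.py hunks below; as a quick orientation, here is a minimal standalone sketch of that pattern (editor's illustration, not the exact code from the diff):

from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
  # Type checkers see an alias to Any, so existing annotations still check.
  ShardedDeviceArray = Any
else:
  class ShardedDeviceArray:
    # At runtime the shim cannot be constructed; callers must migrate to jax.Array.
    def __init__(self):
      raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
                         "and cannot be instantiated.")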
@@ -1820,7 +1820,7 @@ class _PmapFastpathData(NamedTuple):
input_devices: Sequence[xc.Device]
input_indices: Sequence[pxla.Index]
input_array_shardings: Sequence[Any]
# Data needed to build the ShardedDeviceArray from C++.
# Data needed to build the Array from C++.
out_sharding_specs: Sequence[pxla.ShardingSpec]
out_indices: Sequence[pxla.Index]
out_avals: Sequence[Any]
@@ -1891,8 +1891,7 @@ def _cpp_pmap(
# TODO(sharadmv): Enable effects in replicated computation
not execute_replicated.has_unordered_effects
and not execute_replicated.has_host_callbacks and
# No tracers in the outputs. Checking for ShardedDeviceArray should be
# sufficient, but we use the more general `DeviceArray`.
# No tracers in the outputs.
all(
isinstance(x, device_array.DeviceArray) or isinstance(x, xc.ArrayImpl)
for x in out_flat))
@@ -2528,7 +2527,7 @@ def device_put(


def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array shards to specified devices and form ShardedDeviceArray(s).
"""Transfer array shards to specified devices and form Array(s).

Args:
shards: A sequence of arrays, scalars, or (nested) standard Python
@@ -2540,7 +2539,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
This function is always asynchronous, i.e. returns immediately.

Returns:
A ShardedDeviceArray or (nested) Python container thereof representing the
An Array or (nested) Python container thereof representing the
elements of ``shards`` stacked together, with each shard backed by physical
device memory specified by the corresponding entry in ``devices``.

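As context for the docstring change above, a minimal usage sketch of jax.device_put_sharded (editor's illustration, not part of the diff; assumes a runtime with at least one local device):

import jax
import numpy as np

devices = jax.local_devices()
# One shard per device; device_put_sharded stacks them along a new leading axis.
shards = [np.arange(4.0) + i for i in range(len(devices))]
stacked = jax.device_put_sharded(shards, devices)
print(stacked.shape)  # (len(devices), 4)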
@@ -2603,7 +2602,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #


def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array(s) to each specified device and form ShardedDeviceArray(s).
"""Transfer array(s) to each specified device and form Array(s).

Args:
x: an array, scalar, or (nested) standard Python container thereof
@@ -2614,7 +2613,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
This function is always asynchronous, i.e. returns immediately.

Returns:
A ShardedDeviceArray or (nested) Python container thereof representing the
An Array or (nested) Python container thereof representing the
value of ``x`` broadcasted along a new leading axis of size
``len(devices)``, with each slice along that new leading axis backed by
memory on the device specified by the corresponding entry in ``devices``.
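And the analogous usage sketch for jax.device_put_replicated (editor's illustration, same assumptions as above):

import jax
import numpy as np

devices = jax.local_devices()
# The same value is copied to every device and stacked along a new leading axis.
replicated = jax.device_put_replicated(np.arange(4.0), devices)
print(replicated.shape)  # (len(devices), 4)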
@@ -38,11 +38,11 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import operator as op
import sys
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)

import numpy as np

@@ -51,9 +51,7 @@ from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map

from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
@@ -383,18 +381,8 @@ def shard_arg(arg, devices, arg_indices, sharding):
devices: The list of devices to shard over.
arg_indices: A list of `len(devices)` indices to use to shard the argument.
"""
if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
# The shard_arg_handlers allow an extensible set of types to be sharded, but
# inline handling for ShardedDeviceArray as a special case for performance
# NOTE: we compare indices instead of sharding_spec because
# pmap_benchmark.pmapshard_args_benchmark indicates this is faster.
return [
buf if buf.device() == d else buf.copy_to_device(d)
for d, buf in zip(devices, arg.device_buffers)
]
else:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)


@profiler.annotate_function
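The removed fast path fell back to the shard_arg_handlers registry, which after this hunk is the only path. A self-contained sketch of that dispatch-by-type pattern, using made-up names rather than JAX's actual tables (editor's illustration):

from typing import Any, Callable, Dict, Sequence

_shard_handlers: Dict[type, Callable[[Any, Sequence[str]], list]] = {}

def register_handler(ty: type, fn: Callable[[Any, Sequence[str]], list]) -> None:
  _shard_handlers[ty] = fn

def shard(arg: Any, devices: Sequence[str]) -> list:
  # Dispatch on the runtime type of `arg`, mirroring
  # shard_arg_handlers[type(arg)](arg, ...) above.
  return _shard_handlers[type(arg)](arg, devices)

# Example handler: round-robin a Python list across "devices".
register_handler(list, lambda xs, devs: [xs[i::len(devs)] for i in range(len(devs))])

print(shard([0, 1, 2, 3], ["dev0", "dev1"]))  # [[0, 2], [1, 3]]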
@@ -626,9 +614,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}

### lazy device-memory persistence and result handling

# TODO(jblespiau): Consider removing this option.
_USE_CPP_SDA = True


def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
if sharded_dim is not None:
@@ -685,92 +670,15 @@ def make_sharded_device_array(
return jax.make_array_from_single_device_arrays(
aval.shape, sharding, device_buffers) # type: ignore

if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore


if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
basearray.Array.register(_SDA_BASE_CLASS)


class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
"""A ShardedDeviceArray is an ndarray sharded across devices.

The purpose of a ShardedDeviceArray is to reduce the number of transfers when
executing replicated computations, by allowing results to persist on the
devices that produced them. That way dispatching a similarly replicated
computation that consumes the same sharded memory layout does not incur any
transfers.

A ShardedDeviceArray represents one logical ndarray value, and simulates the
behavior of an ndarray so that it can be treated by user code as an ndarray;
that is, it is only an optimization to reduce transfers.

Attributes:
aval: A ShapedArray indicating the shape and dtype of this array.
sharding_spec: describes how this array is sharded across `device_buffers`.
device_buffers: the buffers containing the data for this array. Each buffer
is the same shape and on a different device. Buffers are in row-major
order, with replication treated as an extra innermost dimension.
indices: the result of spec_to_indices(sharding_spec). Can optionally be
precomputed for efficiency. A list the same length as
`device_buffers`. Each index indicates what portion of the full array is
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
np.asarray(device_buffers[i])`.
"""
__slots__ = [
"aval", "device_buffers", "sharding_spec", "indices",
"_one_replica_buffer_indices", "_npy_value"
]

def __init__(self,
aval: ShapedArray,
sharding_spec: ShardingSpec,
device_buffers: List[xb.xla_client.Buffer],
indices: Optional[Tuple[Index, ...]] = None):
super().__init__()

# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)

self.aval = aval
self.device_buffers = device_buffers
self.sharding_spec = sharding_spec
self.indices = indices
self._npy_value = None
self._one_replica_buffer_indices = None
if config.jax_enable_checks:
assert type(aval) is ShapedArray

@property
def shape(self):
return self.aval.shape

@property
def dtype(self):
return self.aval.dtype

@property
def size(self):
return math.prod(self.aval.shape)

@property
def ndim(self):
return len(self.aval.shape)

def delete(self):
if self.device_buffers is None:
return
for buf in self.device_buffers:
buf.delete()
self.device_buffers = None
self._npy_value = None

class ShardedDeviceArray(object):
def __init__(self):
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
"and cannot be instantiated.")

def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
"""Returns a set of buffer-indices containing one complete copy of the array."""
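The new make_sharded_device_array body at the top of this hunk forwards to jax.make_array_from_single_device_arrays. A hedged usage sketch of that replacement path (editor's illustration; assumes at least one local device and jax 0.4-era sharding APIs):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.local_devices())
mesh = Mesh(devices, ('x',))
sharding = NamedSharding(mesh, P('x'))  # shard the leading axis across 'x'
global_shape = (2 * len(devices), 4)
# One per-device array holding that device's (2, 4) slice of the global value.
shards = [jax.device_put(np.full((2, 4), i, np.float32), d)
          for i, d in enumerate(devices)]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, shards)
print(arr.shape)  # (2 * len(devices), 4)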
@@ -783,120 +691,6 @@ def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
seen_index_hashes.add(hashed_index)
return one_replica_indices


def _sda_one_replica_buffer_indices(self):
"""Indices of buffers containing one complete copy of the array data."""
if self._one_replica_buffer_indices is None:
self._one_replica_buffer_indices = _one_replica_buffer_indices(self.indices)
return self._one_replica_buffer_indices


def _sda_copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()


def _sda_check_if_deleted(self):
if self.device_buffers is None:
raise ValueError("ShardedDeviceArray has been deleted.")


def _sda_block_until_ready(self):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_until_ready()
return self


def _sda_value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
self._npy_value = npy_value
return self._npy_value


def _sda__getitem__(self, idx):
self._check_if_deleted()
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.shape, self.aval.dtype)
return device_array.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)


def _sda__iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0]))

def _sda__reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))


def _sda_sharding(self):
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding_spec.sharding)
if has_unstacked:
devices = np.array([d.device() for d in self.device_buffers])
return sharding_impls.PmapSharding(devices, self.sharding_spec)
raise NotImplementedError(
'SDAs that are the output of pjit/xmap do not have the sharding attribute '
'implemented. If you are trying to pass the SDA to pjit/xmap, please '
'use multihost_utils.host_local_array_to_global_array(...) to convert '
'SDAs to global `jax.Array` and then pass them to pjit/xmap with '
'`jax_array` enabled.')

# TODO(yashkatariya): Remove this when SDA is deleted. The local import of Array
# will also go away.
def _sda_addressable_shards(self):
from jax._src import array
out = []
for db in self.device_buffers:
db = dispatch._set_aval(db)
out.append(array.Shard(db.device(), self.sharding, self.shape, db))
return out


for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
setattr(sda, "one_replica_buffer_indices",
property(_sda_one_replica_buffer_indices))
setattr(sda, "copy_to_host_async", _sda_copy_to_host_async)
setattr(sda, "_check_if_deleted", _sda_check_if_deleted)
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)
setattr(sda, "__iter__", _sda__iter__)
setattr(sda, "__reversed__", _sda__reversed__)
setattr(sda, "sharding", property(_sda_sharding))
setattr(sda, "addressable_shards", property(_sda_addressable_shards))

del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__,
_sda_sharding, _sda_addressable_shards)


ShardedDeviceArray: Type[object]
if _USE_CPP_SDA:
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
else:
ShardedDeviceArray = _ShardedDeviceArray


def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)

@@ -936,24 +730,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):

return batched_device_put(x.aval, sharding, bufs, devices)


def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(np.asarray(val),
canonicalize_types=canonicalize_types)

def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = shard_sharded_device_array_slow_path
mlir.register_constant_handler(sda,
_sharded_device_array_mlir_constant_handler)

core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")

_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)

### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

@@ -1049,7 +825,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
donate_argnums_ = donate_argnums
if isinstance(outval, (ShardedDeviceArray, array.ArrayImpl)):
if isinstance(outval, array.ArrayImpl):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(
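Downstream code that special-cased ShardedDeviceArray can migrate the same way as the hunk above: check against jax.Array (or jax._src.array.ArrayImpl) instead. A small before/after sketch (editor's illustration, not from the commit):

import jax
import jax.numpy as jnp

x = jax.pmap(lambda v: v * 2)(jnp.arange(jax.device_count()))

# Before this change, callers often wrote:
#   if isinstance(x, pxla.ShardedDeviceArray): ...
# After it, the public check is simply:
if isinstance(x, jax.Array):
  print("pmap output handled uniformly as a jax.Array")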
@@ -62,7 +62,6 @@ from jax._src.lax.utils import (
standard_primitive,
standard_translate,
)
from jax._src.lib import pmap_lib
from jax._src.lib import pytree
from jax._src import xla_bridge
from jax._src.lib import xla_client, xla_extension_version
@@ -1529,14 +1528,10 @@ def zeros_like_array(x: ArrayLike) -> Array:

for t in itertools.chain(
dtypes.python_scalar_dtypes.keys(), array_types,
device_array.device_array_types, [array.ArrayImpl],
[pxla.ShardedDeviceArray, pxla._ShardedDeviceArray,
pmap_lib.ShardedDeviceArray]):
device_array.device_array_types, [array.ArrayImpl]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[pmap_lib.ShardedDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[array.ArrayImpl] = zeros_like_array

@@ -5692,6 +5692,4 @@ def _set_device_array_attributes(device_array):

for t in device_array.device_array_types:
_set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pmap_lib.ShardedDeviceArray)
_set_device_array_attributes(ArrayImpl)
@@ -45,7 +45,6 @@ from jax._src.interpreters.pxla import (
ShardInfo as ShardInfo,
ShardedAxis as ShardedAxis,
ShardedDeviceArray as ShardedDeviceArray,
ShardedDeviceArrayBase as ShardedDeviceArrayBase,
ShardingSpec as ShardingSpec,
TileManual as TileManual,
TileVectorize as TileVectorize,
@@ -54,7 +53,6 @@ from jax._src.interpreters.pxla import (
UnloadedPmapExecutable as UnloadedPmapExecutable,
Unstacked as Unstacked,
WeakRefList as WeakRefList,
_ShardedDeviceArray as _ShardedDeviceArray,
_UNSPECIFIED as _UNSPECIFIED,
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
_get_and_check_device_assignment as _get_and_check_device_assignment,
@@ -54,7 +54,6 @@ from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P
@@ -1489,13 +1488,9 @@ class APITest(jtu.JaxTestCase):
def test_is_subclass(self):
self.assertTrue(issubclass(device_array.DeviceArray, jax.Array))
self.assertTrue(issubclass(device_array.Buffer, jax.Array))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jax.Array))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jax.Array))
self.assertFalse(issubclass(np.ndarray, jax.Array))
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
self.assertFalse(issubclass(pxla._ShardedDeviceArray, np.ndarray))

def test_is_instance(self):
def f(x):
@@ -108,8 +108,6 @@ class JaxJitTest(jtu.JaxTestCase):
sda = pmaped_f(np.asarray([[1]]))
output_buffer = device_put_function(sda, device=device)

self.assertNotIsInstance(output_buffer,
jax.interpreters.pxla.ShardedDeviceArray)
self.assertEqual(output_buffer.dtype, sda.dtype)
self.assertEqual(output_buffer.aval, sda.aval)
np.testing.assert_array_equal(output_buffer, np.asarray(sda))
@@ -759,14 +759,14 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)

def testShardedDeviceArrays(self):
def testArrays(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

# test that we can pass in and out ShardedDeviceArrays
# test that we can pass in and out Arrays
y = f(x)
self.assertIsInstance(y, jax.Array)
self.assertIsInstance(y, array.ArrayImpl)
@@ -782,7 +782,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIsInstance(y, array.ArrayImpl)
self.assertAllClose(y, 2 * x, check_dtypes=False)

# test that we can pass a ShardedDeviceArray to a regular jit computation
# test that we can pass an Array to a regular jit computation
z = y + y
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)

@@ -809,7 +809,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for in_shape, out_shape in [
[(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
def testArrayReshape(self, in_shape, out_shape):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")

@@ -1551,7 +1551,7 @@ class PythonPmapTest(jtu.JaxTestCase):
1, 2).reshape(shape)
self.assertAllClose(fn(x, w), expected, check_dtypes=False)

def testShardedDeviceArrayBlockUntilReady(self):
def testArrayBlockUntilReady(self):
x = np.arange(jax.device_count())
x = self.pmap(lambda x: x)(x)
x.block_until_ready() # doesn't crash
@@ -1602,7 +1602,7 @@ class PythonPmapTest(jtu.JaxTestCase):

multi_step_pmap(jnp.zeros((device_count,)), count=1)

def testShardedDeviceArrayGetItem(self):
def testArrayGetItem(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')

@@ -2612,7 +2612,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):


@jtu.pytest_mark_if_available('multiaccelerator')
class ShardedDeviceArrayTest(jtu.JaxTestCase):
class ArrayTest(jtu.JaxTestCase):

def testThreadsafeIndexing(self):
# NOTE(skye): I picked these values to be big enough to cause interesting
@@ -2868,7 +2868,7 @@ class ShardArgsTest(jtu.JaxTestCase):
def device_array(x):
return jax.device_put(x)

# TODO(skye): add coverage for ShardedDeviceArrays
# TODO(skye): add coverage for Arrays

@parameterized.named_parameters(
{"testcase_name":