Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
commit cbf34cb609
parent 71bcabe499
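The diff below renames the concrete implementation class `Array` in jax.experimental.array to `ArrayImpl` and bumps the xla_client version gates (92/94/95) to 96; the public abstract base type (spelled `jax.Array` in the updated tests) is not touched, so user-level type checks keep working. A minimal sketch of what this means for call sites, assuming a jax checkout from around this revision:

import jax
import jax.numpy as jnp

x = jnp.arange(4)

# Code that only asks "is this a JAX array?" should test against the public
# base class, which this rename does not change:
assert isinstance(x, jax.Array)

# Checks against the concrete class are what the diff rewrites throughout the tree:
#   isinstance(x, array.Array)      # spelling before this commit
#   isinstance(x, array.ArrayImpl)  # spelling after this commit
# where `array` refers to jax.experimental.array at this revision.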
@@ -497,7 +497,7 @@ def _cpp_jit_clear_cache(self):

def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
use_fastpath = (
-xc._version >= 92 and
+xc._version >= 96 and
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
@@ -506,7 +506,7 @@ def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
not execute.ordered_effects and
not execute.has_unordered_effects and
not execute.has_host_callbacks and
-all(isinstance(x, xc.Array) for x in out_flat) and
+all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes
# TODO(chky): Check sharding is SingleDeviceSharding
@@ -2215,7 +2215,7 @@ def _cpp_pmap(
# sufficient, but we use the more general `DeviceArray`.
all(
isinstance(x, device_array.DeviceArray) or
-xc._version >= 94 and isinstance(x, xc.Array) for x in out_flat))
+xc._version >= 96 and isinstance(x, xc.ArrayImpl) for x in out_flat))

### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
@@ -2871,7 +2871,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
-return array.Array(
+return array.ArrayImpl(
stacked_aval,
sharding.PmapSharding(np.array(devices), sharding_spec),
buffers, committed=True, _skip_checks=True)
@@ -2926,7 +2926,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(aval)
-return array.Array(
+return array.ArrayImpl(
aval, sharding.PmapSharding(np.array(devices), sharding_spec),
[buf, *rest_bufs], committed=True, _skip_checks=True)
else:

@@ -57,8 +57,6 @@ import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath

-if TYPE_CHECKING:
-from jax.experimental.array import Array

FLAGS = flags.FLAGS

@@ -750,10 +748,10 @@ class SimpleResultHandler:

def maybe_create_array_from_da(buf, aval, device):
if config.jax_array:
-from jax.experimental.array import Array
+from jax.experimental.array import ArrayImpl
from jax.experimental.sharding import SingleDeviceSharding
-return Array(aval, SingleDeviceSharding(buf.device()), [buf],
-committed=(device is not None), _skip_checks=True)
+return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
+committed=(device is not None), _skip_checks=True)
else:
return device_array.make_device_array(aval, device, buf)

@@ -1222,7 +1220,7 @@ def _copy_device_array_to_device(
return device_array.make_device_array(x.aval, device, moved_buf)


-def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
+def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
"""Copies `Array`s with SingleDeviceSharding to a different device."""
from jax.experimental import array, sharding

@@ -1246,7 +1244,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
-return array.Array(
+return array.ArrayImpl(
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
committed=(device is not None))

@@ -1257,7 +1255,7 @@ def _device_put_impl(x, device: Optional[Device] = None):
if device_array.type_is_device_array(x):
return _copy_device_array_to_device(x, device)

-if type(x) is array.Array and isinstance(x.sharding, sharding.SingleDeviceSharding):
+if type(x) is array.ArrayImpl and isinstance(x.sharding, sharding.SingleDeviceSharding):
return _copy_array_to_device(x, device)

try:

@@ -40,10 +40,10 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
owns.
"""
from jax.experimental import array
-if not isinstance(x, (device_array.DeviceArray, array.Array)):
+if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
.format(type(x)))
-if isinstance(x, array.Array):
+if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
else:

@@ -86,7 +86,7 @@ zip, unsafe_zip = safe_zip, zip
def _is_array_or_tracer(operand: Any) -> bool:
if config.jax_array:
from jax.experimental import array # pylint: disable=g-import-not-at-top
-return isinstance(operand, (core.Tracer, array.Array))
+return isinstance(operand, (core.Tracer, array.ArrayImpl))
else:
return isinstance(operand, (core.Tracer, device_array.DeviceArray))

@@ -944,8 +944,6 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array:
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
operator.
"""
-from jax.experimental import array

permutation = tuple(operator.index(d) for d in permutation)
if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
return type_cast(Array, operand)

@@ -78,7 +78,7 @@ from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
-from jax.experimental.array import Array
+from jax.experimental.array import ArrayImpl

newaxis = None

@@ -1876,7 +1876,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

# We can't use the ndarray class because we need to handle internal buffers
# (See https://github.com/google/jax/issues/8950)
-ndarray_types = (device_array.DeviceArray, core.Tracer, Array)
+ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)

if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
@@ -4772,7 +4772,7 @@ _NOT_IMPLEMENTED = ['argpartition']

# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
-_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, Array)
+_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)

def __array_module__(self, types):
@@ -4811,7 +4811,7 @@ def _multi_slice(arr,
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(device_array.DeviceArray, "_unstack", _unstack)
-setattr(Array, '_unstack', _unstack)
+setattr(ArrayImpl, '_unstack', _unstack)

def _chunk_iter(x, size):
if size > x.shape[0]:
@@ -4823,7 +4823,7 @@ def _chunk_iter(x, size):
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter)
-setattr(Array, '_chunk_iter', _chunk_iter)
+setattr(ArrayImpl, '_chunk_iter', _chunk_iter)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
@@ -5150,7 +5150,7 @@ def _set_device_array_base_attributes(device_array, include=None, exclude=None):
maybe_setattr("clip", _clip)

_set_device_array_base_attributes(device_array.DeviceArray)
-_set_device_array_base_attributes(Array, exclude={'__getitem__'})
+_set_device_array_base_attributes(ArrayImpl, exclude={'__getitem__'})


def _set_device_array_attributes(device_array):
@@ -5167,4 +5167,4 @@ for t in device_array.device_array_types:
_set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pxla.pmap_lib.ShardedDeviceArray)
-_set_device_array_attributes(Array)
+_set_device_array_attributes(ArrayImpl)

@@ -18,6 +18,7 @@ import operator as op
import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List

+import jax
from jax import core
from jax._src import abstract_arrays
from jax._src import ad_util
@@ -55,7 +56,7 @@ class Shard:
"""

def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
-data: Optional[Array] = None):
+data: Optional[ArrayImpl] = None):
self.device = device
self._sharding = sharding
self._global_shape = global_shape
@@ -96,12 +97,12 @@ def _reconstruct_array(fun, args, arr_state, aval_state):

def _single_device_array_from_buf(buf, committed):
db = pxla._set_aval(buf)
-return Array(db.aval, SingleDeviceSharding(db.device()), [db],
-committed=committed, _skip_checks=True)
+return ArrayImpl(db.aval, SingleDeviceSharding(db.device()), [db],
+committed=committed, _skip_checks=True)


-@pxla.use_cpp_class(xc.Array if xc._version >= 92 else None)
-class Array(basearray.Array):
+@pxla.use_cpp_class(xc.ArrayImpl if xc._version >= 96 else None)
+class ArrayImpl(basearray.Array):
"""Experimental unified Array type.

This Python implementation will eventually be replaced by a C++ implementation.
@@ -117,7 +118,7 @@ class Array(basearray.Array):

@pxla.use_cpp_method
def __init__(self, aval: core.ShapedArray, sharding: Sharding,
-arrays: Union[Sequence[DeviceArray], Sequence[Array]],
+arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
committed: bool, _skip_checks: bool = False):
# NOTE: the actual implementation of the constructor is moved to C++.

@@ -271,8 +272,8 @@ class Array(basearray.Array):
if buf_idx is not None:
buf = self._arrays[buf_idx]
aval = core.ShapedArray(buf.shape, self.dtype)
-return Array(aval, SingleDeviceSharding(buf.device()), [buf],
-committed=False, _skip_checks=True)
+return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
+committed=False, _skip_checks=True)
return lax_numpy._rewriting_take(self, idx)
else:
# TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or
@@ -385,7 +386,7 @@ class Array(basearray.Array):
return [_single_device_array_from_buf(a, self._committed)
for a in self._arrays]

-def addressable_data(self, index: int) -> Array:
+def addressable_data(self, index: int) -> ArrayImpl:
self._check_if_deleted()
return _single_device_array_from_buf(self._arrays[index], self._committed)

@@ -474,12 +475,12 @@ class Array(basearray.Array):
return cast(np.ndarray, self._npy_value)

# explicitly set to be unhashable. Same as what device_array.py does.
-setattr(Array, "__hash__", None)
-setattr(Array, "__array_priority__", 100)
+setattr(ArrayImpl, "__hash__", None)
+setattr(ArrayImpl, "__array_priority__", 100)


-def make_array_from_callback(shape: Shape, sharding: Sharding,
-data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
+def make_array_from_callback(
+shape: Shape, sharding: Sharding,
+data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
device_to_index_map = sharding.devices_indices_map(shape)
# Use addressable_devices here instead of `_addressable_device_assignment`
# because `_addressable_device_assignment` is only available on
@@ -490,32 +491,32 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
for device in sharding.addressable_devices
]
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
-return Array(aval, sharding, arrays, committed=True)
+return ArrayImpl(aval, sharding, arrays, committed=True)


-def make_array_from_single_device_arrays(shape: Shape, sharding: Sharding,
-arrays: Sequence[Array]) -> Array:
+def make_array_from_single_device_arrays(
+shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl:
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
-return Array(aval, sharding, arrays, committed=True)
+return ArrayImpl(aval, sharding, arrays, committed=True)


-core.pytype_aval_mappings[Array] = abstract_arrays.canonical_concrete_aval
-xla.pytype_aval_mappings[Array] = op.attrgetter('aval')
-xla.canonicalize_dtype_handlers[Array] = pxla.identity
-api_util._shaped_abstractify_handlers[Array] = op.attrgetter('aval')
-ad_util.jaxval_adders[Array] = lax_internal.add
-ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
-if xc._version >= 92:
+core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
+xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
+xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
+api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
+ad_util.jaxval_adders[ArrayImpl] = lax_internal.add
+ad_util.jaxval_zeros_likers[ArrayImpl] = lax_internal.zeros_like_array
+if xc._version >= 96:
# TODO(jakevdp) replace this with true inheritance at the C++ level.
-basearray.Array.register(Array)
+basearray.Array.register(ArrayImpl)


def _array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(val._value,
canonicalize_types=canonicalize_types)
-mlir.register_constant_handler(Array, _array_mlir_constant_handler)
+mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)


def _device_put_array(x, device: Optional[Device]):
@@ -529,7 +530,7 @@ def _device_put_array(x, device: Optional[Device]):
# Round trip via host if x is sharded. SDA also does a round trip via host.
return dispatch._device_put_array(x._value, device)

-dispatch.device_put_handlers[Array] = _device_put_array
+dispatch.device_put_handlers[ArrayImpl] = _device_put_array


def _array_pmap_shard_arg(x, devices, indices, mode):
@@ -586,7 +587,7 @@ def _array_shard_arg(x, devices, indices, mode):
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
return _array_rest_shard_arg(x, devices, indices, mode)
-pxla.shard_arg_handlers[Array] = _array_shard_arg
+pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg


def _array_global_result_handler(global_aval, out_sharding, committed,
@@ -596,8 +597,8 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
-return lambda bufs: Array(global_aval, out_sharding, bufs,
-committed=committed, _skip_checks=True)
+return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
+committed=committed, _skip_checks=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
@@ -608,7 +609,7 @@ def _array_local_result_handler(aval, sharding, indices):
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
-return lambda bufs: Array(aval, sharding, bufs, committed=True,
-_skip_checks=True)
+return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
+_skip_checks=True)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler

@@ -55,8 +55,8 @@ async def create_async_array_from_callback(

dbs = [jax.device_put(array, device)
for array, device in zip(local_arrays, addressable_da)]
-aval = jax.ShapedArray(global_shape, dbs[0].dtype)
-return array.Array(aval, inp_sharding, dbs, committed=True)
+return array.make_array_from_single_device_arrays(
+global_shape, inp_sharding, dbs)


async def create_async_gda_from_callback(
@@ -86,7 +86,7 @@ def _get_metadata(arr):
dtype = 'bfloat16'
else:
dtype = np.dtype(arr.dtype).str
-if isinstance(arr, array.Array):
+if isinstance(arr, array.ArrayImpl):
local_shape = arr._arrays[0].shape
else:
local_shape = arr.local_data(0).shape
@@ -150,7 +150,7 @@ class _LimitInFlightBytes:


async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
-if (isinstance(arr_inp, array.Array) and jax.process_count() > 1 and
+if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
arr_inp.is_fully_addressable()):
raise ValueError('Passing fully addressable Arrays to a multi-host '
'serialization is not allowed.')
@@ -187,7 +187,7 @@ async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
else:
await write_future.commit

-if isinstance(arr_inp, array.Array):
+if isinstance(arr_inp, array.ArrayImpl):
local_shards = arr_inp.addressable_shards
else:
local_shards = arr_inp.local_shards

@@ -134,7 +134,7 @@ class CheckpointTest(jtu.JaxTestCase):
[pspec, P('x'), P(None)],
tspecs)

-self.assertIsInstance(m1, array.Array)
+self.assertIsInstance(m1, array.ArrayImpl)
self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
np.array([[0], [2]]))
self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
@@ -142,7 +142,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)

-self.assertIsInstance(m2, array.Array)
+self.assertIsInstance(m2, array.ArrayImpl)
self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
@@ -150,7 +150,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)

-self.assertIsInstance(m3, array.Array)
+self.assertIsInstance(m3, array.ArrayImpl)
for i, s in enumerate(m3.addressable_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)

@@ -39,7 +39,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
-from jax.experimental.array import Array
+from jax.experimental.array import ArrayImpl
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.sharding import MeshPspecSharding
from jax.interpreters import mlir
@@ -1826,7 +1826,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
-if isinstance(arg, (GlobalDeviceArray, Array)):
+if isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
if arr_flavor == 'Array' and not isinstance(arg.sharding, MeshPspecSharding):
continue

@@ -169,7 +169,7 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
-all(isinstance(x, xc.Array) for x in out_flat)
+all(isinstance(x, xc.ArrayImpl) for x in out_flat)
)

if use_fastpath:
@@ -428,7 +428,7 @@ def pjit(fun: Callable,
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)

-if FLAGS.experimental_cpp_pjit and xc._version >= 95:
+if FLAGS.experimental_cpp_pjit and xc._version >= 96:
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
else:
wrapped = _python_pjit(fun, infer_params)

@@ -873,10 +873,10 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
-from jax.experimental.array import Array
+from jax.experimental.array import ArrayImpl

candidates = defaultdict(list)
-if isinstance(x, Array):
+if isinstance(x, ArrayImpl):
bufs = x._arrays
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
@@ -1002,7 +1002,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
donate_argnums_ = donate_argnums
-if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
+if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.ArrayImpl)):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(
@@ -3349,7 +3349,7 @@ def _out_shardings_for_trivial(
device_assignment, sharding._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
-if isinstance(constval, array.Array):
+if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
map(shardings.setdefault, jaxpr.invars, in_shardings)
return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
@@ -3374,7 +3374,7 @@ def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):

def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray
-from jax.experimental.array import Array
+from jax.experimental.array import ArrayImpl

@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
@@ -3387,7 +3387,7 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
f"xla sharding: {in_xla_sharding}")

for arg, xs in safe_zip(args, in_xla_shardings):
-if not isinstance(arg, (GlobalDeviceArray, Array)):
+if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
continue
if isinstance(arg, GlobalDeviceArray):
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,

@@ -79,7 +79,7 @@ numpy_version = tuple(map(int, np.__version__.split('.')[:3]))

def _check_instance(self, x):
if config.jax_array:
-self.assertIsInstance(x, array.Array)
+self.assertIsInstance(x, array.ArrayImpl)
else:
self.assertIsInstance(x, device_array.DeviceArray)

@@ -191,7 +191,7 @@ class JaxArrayTest(jtu.JaxTestCase):

def test_jnp_array(self):
arr = jnp.array([1, 2, 3])
-self.assertIsInstance(arr, array.Array)
+self.assertIsInstance(arr, array.ArrayImpl)
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
self.assertEqual(arr._committed, False)

@@ -199,13 +199,13 @@ class JaxArrayTest(jtu.JaxTestCase):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = jax.jit(lambda x, y: x + y)(a, b)
-self.assertIsInstance(arr, array.Array)
+self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

def test_jnp_array_jnp_add(self):
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
-self.assertIsInstance(arr, array.Array)
+self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

@@ -213,7 +213,7 @@ class JaxArrayTest(jtu.JaxTestCase):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = a + b
-self.assertIsInstance(arr, array.Array)
+self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

@@ -279,13 +279,13 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 4'):
-array.Array(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
+array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)

with self.assertRaisesRegex(
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 16'):
-array.Array(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
+array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)

def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
@@ -299,7 +299,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Some per-device arrays are placed on devices {2, 3}, which are "
"not used in the specified sharding"):
-array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
@@ -334,7 +334,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
-array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
@@ -347,7 +347,7 @@ class JaxArrayTest(jtu.JaxTestCase):

sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):
-self.assertIsInstance(i, array.Array)
+self.assertIsInstance(i, array.ArrayImpl)
self.assertArraysAllClose(i, j)

def test_array_iter_pmap_sharding_last_dim_sharded(self):
@@ -433,10 +433,10 @@ class JaxArrayTest(jtu.JaxTestCase):
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))

for a in arr.device_buffers:
-self.assertIsInstance(a, array.Array)
+self.assertIsInstance(a, array.ArrayImpl)

x = jnp.array([1, 2, 3])
-self.assertIsInstance(x.device_buffer, array.Array)
+self.assertIsInstance(x.device_buffer, array.ArrayImpl)


@jtu.with_config(jax_array=True)
@@ -539,8 +539,8 @@ class ShardingTest(jtu.JaxTestCase):

def test_is_subclass(self):
# array version of api_test.py::APITest::test_is_subclass
-self.assertTrue(issubclass(array.Array, jnp.ndarray))
-self.assertFalse(issubclass(array.Array, np.ndarray))
+self.assertTrue(issubclass(array.ArrayImpl, jnp.ndarray))
+self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))


if __name__ == '__main__':

@@ -26,7 +26,6 @@ import numpy as np
import jax
from jax._src import dtypes
from jax import numpy as jnp
-from jax.experimental import array

from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
@@ -76,13 +75,6 @@ def identity(x):
return x


-def _check_instance(self, x):
-if config.jax_array:
-self.assertIsInstance(x, array.Array)
-else:
-self.assertIsInstance(x, jnp.DeviceArray)


class DtypesTest(jtu.JaxTestCase):

def test_canonicalize_type(self):
@@ -231,7 +223,7 @@ class DtypesTest(jtu.JaxTestCase):
def testScalarInstantiation(self, scalar_type):
a = scalar_type(1)
self.assertEqual(a.dtype, jnp.dtype(scalar_type))
-_check_instance(self, a)
+self.assertIsInstance(a, jax.Array)
self.assertEqual(0, jnp.ndim(a))
self.assertIsInstance(np.dtype(scalar_type).type(1), scalar_type)

@@ -34,7 +34,6 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
-from jax.experimental import array
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@@ -2464,10 +2463,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]
-if config.jax_array:
-self.assertIsInstance(ext_res, array.Array)
-else:
-self.assertIsInstance(ext_res, jnp.DeviceArray)
+self.assertIsInstance(ext_res, jax.Array)

def test_scan_vmap_collectives(self):
def scan_f(state, x):

@@ -2783,7 +2783,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):

out = jnp.concatenate([np_input])
if config.jax_array:
-self.assertIs(type(out), array.Array)
+self.assertIs(type(out), array.ArrayImpl)
else:
self.assertTrue(device_array.type_is_device_array(out))

@@ -4154,8 +4154,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):

def testArrayOutputsDeviceArrays(self):
if config.jax_array:
-assert type(jnp.array([])) is array.Array
-assert type(jnp.array(np.array([]))) is array.Array
+assert type(jnp.array([])) is array.ArrayImpl
+assert type(jnp.array(np.array([]))) is array.ArrayImpl
else:
assert device_array.type_is_device_array(jnp.array([]))
assert device_array.type_is_device_array(jnp.array(np.array([])))
@@ -4164,7 +4164,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def __array__(self, dtype=None):
return np.array([], dtype=dtype)
if config.jax_array:
-assert type(jnp.array(NDArrayLike())) is array.Array
+assert type(jnp.array(NDArrayLike())) is array.ArrayImpl
else:
assert device_array.type_is_device_array(jnp.array(NDArrayLike()))

@@ -1452,8 +1452,7 @@ class LaxTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(operand)
-expected_type = array.Array if config.jax_array else jnp.DeviceArray
-self.assertIsInstance(result, expected_type)
+self.assertIsInstance(result, jax.Array)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}".format(
@@ -2915,10 +2914,7 @@ class LazyConstantTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(input_type(value))
-if config.jax_array:
-assert isinstance(result, array.Array)
-else:
-assert isinstance(result, jnp.DeviceArray)
+assert isinstance(result, jax.Array)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype_in={}_dtype_out={}".format(
@@ -3212,13 +3208,13 @@ class FooArray:
ndim = property(lambda self: self.data.ndim - 1)

def device_put_foo_array(x: FooArray, device):
-if isinstance(x.data, array.Array):
+if isinstance(x.data, array.ArrayImpl):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)

def shard_foo_array_handler(x, devices, indices, mode):
device, = devices
-if isinstance(x.data, array.Array):
+if isinstance(x.data, array.ArrayImpl):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)

@@ -21,7 +21,6 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import lax
-from jax.experimental import array
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge

@@ -198,10 +197,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
devices = self.get_devices()

def f(): return lax.add(3., 4.)
-if config.jax_array:
-self.assertIsInstance(f(), array.Array)
-else:
-self.assertIsInstance(f(), jnp.DeviceArray)
+self.assertIsInstance(f(), jax.Array)
self.assert_uncommitted_to_device(f(), devices[0])
self.assert_uncommitted_to_device(jax.jit(f)(), devices[0])
self.assert_committed_to_device(jax.jit(f, device=devices[1])(),

@@ -113,7 +113,7 @@ def simulated_cached_fun(s):

def _check_instance(self, x):
if config.jax_array:
-self.assertIsInstance(x, array.Array)
+self.assertIsInstance(x, array.ArrayImpl)
else:
self.assertIsInstance(x, pxla.ShardedDeviceArray)

@@ -404,7 +404,7 @@ class PJitTest(jtu.BufferDonationTestCase):
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
-self.assertIsInstance(actual, array.Array)
+self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
@@ -433,7 +433,7 @@ class PJitTest(jtu.BufferDonationTestCase):
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
-self.assertIsInstance(actual, array.Array)
+self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
@@ -1539,7 +1539,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings]
out = compiled(*inputs)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data)


@@ -1646,7 +1646,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
for ip in compiled.input_shardings]
out1, out2 = compiled(*inputs)
for o in [out1, out2]:
-self.assertIsInstance(o, array.Array)
+self.assertIsInstance(o, array.ArrayImpl)
self.assertArraysEqual(o._value, input_data)

@unittest.skip('The error is not raised yet. Enable this back once we raise '
@@ -1696,7 +1696,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_matrix_mul = input_data @ input_data.T

out = f(input_array)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
@@ -1723,7 +1723,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_matrix_mul = input_data @ input_data.T

out = f(input_array)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
@@ -1744,7 +1744,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Since no in_axis_resources is provided, pjit will assume that
# the numpy input is fully replicated over the mesh.
out = f(input_data)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
@@ -1763,7 +1763,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out_axis_resources=MeshPspecSharding(
global_mesh, P('x', 'y')))
out = f(input_data)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
@@ -1773,7 +1773,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_unspecified_out_axis_resources(self):

def _checks(out, input_data):
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertIsInstance(out.sharding, OpShardingSharding)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
@@ -1826,25 +1826,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
out_tree = f((a1, (a2, (a3, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)

-self.assertIsInstance(out1, array.Array)
+self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-self.assertIsInstance(out2, array.Array)
+self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-self.assertIsInstance(out3, array.Array)
+self.assertIsInstance(out3, array.ArrayImpl)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-self.assertIsInstance(out4, array.Array)
+self.assertIsInstance(out4, array.ArrayImpl)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
for s in out4.addressable_shards:
@@ -1879,7 +1879,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = pjit(
lambda x: x,
in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)

def test_no_input_output(self):
with jax_array(True):
@@ -1918,7 +1918,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))
compiled = f.lower(aval, aval).compile()
out = compiled(a1, a1)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data @ input_data.T)

with self.assertRaisesRegex(
@@ -2045,11 +2045,11 @@ class ArrayPjitTest(jtu.JaxTestCase):
def add(x, y):
return x + y
out = add(a, b)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, a + b)

out2 = add(out, out)
-self.assertIsInstance(out2, array.Array)
+self.assertIsInstance(out2, array.ArrayImpl)
self.assertArraysEqual(out2, 2 * (a + b))

@jax_array(True)
@@ -2061,7 +2061,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x @ x.T

out = mul(a)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, a @ a.T)

@jax_array(True)
@@ -2107,7 +2107,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
a = jnp.array(16, dtype=jnp.float32)
f = lambda x: x
out = jax.grad(pjit(f))(a)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, jax.grad(f)(a))

@jax_array(True)
@@ -2137,7 +2137,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# is correct.
bufs = [jax.device_put(inp_data[s.device_indices(d, shape)], d)
for d in jax.local_devices()]
-arr = array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+arr = array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

f = pjit(lambda x: x, out_axis_resources=s)
out = f(arr)
@@ -2225,7 +2225,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
-arr = array.Array(
+arr = array.ArrayImpl(
jax.ShapedArray(shape, np.int32), MeshPspecSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
with self.assertRaisesRegex(

@@ -556,7 +556,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f_ans = f(x, y)
self.assertAllClose(f_ans, f_expected)
if config.jax_array:
-self.assertIsInstance(f_ans, array.Array)
+self.assertIsInstance(f_ans, array.ArrayImpl)
sharding_spec = f_ans.sharding.sharding_spec
else:
self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
@@ -571,7 +571,7 @@ class PythonPmapTest(jtu.JaxTestCase):
g_ans = g(x, y)
self.assertAllClose(g_ans, g_expected)
if config.jax_array:
-self.assertIsInstance(g_ans, array.Array)
+self.assertIsInstance(g_ans, array.ArrayImpl)
sharding_spec = g_ans.sharding.sharding_spec
else:
self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
@@ -722,7 +722,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
if config.jax_array:
-self.assertIsInstance(y, array.Array)
+self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
@@ -730,7 +730,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(y, 2 * x, check_dtypes=False)
z = f(y)
if config.jax_array:
-self.assertIsInstance(z, array.Array)
+self.assertIsInstance(z, array.ArrayImpl)
else:
self.assertIsInstance(z, pxla.ShardedDeviceArray)
self.assertIsInstance(z, device_array.DeviceArray)
@@ -740,7 +740,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# test that we can pass in a regular DeviceArray
y = f(device_put(x))
if config.jax_array:
-self.assertIsInstance(y, array.Array)
+self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
@@ -1606,7 +1606,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
if config.jax_array:
-self.assertIsInstance(y, array.Array)
+self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)

@@ -2585,7 +2585,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
self.assertIsNone(sharded_x._npy_value)

if config.jax_array:
-arr_type = array.Array
+arr_type = array.ArrayImpl
else:
arr_type = device_array.DeviceArray

@@ -2595,7 +2595,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):

@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
-('array', True, array.Array, '_arrays')
+('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_sharded(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@@ -2611,7 +2611,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):

@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
-('array', True, array.Array, '_arrays')
+('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_sharded_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@@ -2632,7 +2632,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):

@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
-('array', True, array.Array, '_arrays')
+('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_replicated(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@@ -2648,7 +2648,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):

@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
-('array', True, array.Array, '_arrays')
+('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@@ -2900,7 +2900,7 @@ class ArrayPmapTest(jtu.JaxTestCase):

expected = input_data * input_data

-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected[s.index])
self.assertArraysEqual(out, expected)
@@ -2918,8 +2918,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
with jax_config.jax_array(True):
out1, out2 = f(input_array, input_array)

-self.assertIsInstance(out1, array.Array)
-self.assertIsInstance(out2, array.Array)
+self.assertIsInstance(out1, array.ArrayImpl)
+self.assertIsInstance(out2, array.ArrayImpl)
for s1, s2 in safe_zip(out1.addressable_shards, out2.addressable_shards):
self.assertArraysEqual(s1.data._arrays[0], input_data[s1.index])
self.assertArraysEqual(s2.data._arrays[0], input_data[s2.index])
@@ -2942,8 +2942,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
with jax_config.jax_array(True):
out1, out2 = f(a1, a2)

-self.assertIsInstance(out1, array.Array)
-self.assertIsInstance(out2, array.Array)
+self.assertIsInstance(out1, array.ArrayImpl)
+self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out1.shape, (2,))
self.assertEqual(out2.shape, (dc, dc, 2))
for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):

@@ -27,7 +27,7 @@ from jax._src import typing
from jax import lax
import jax.numpy as jnp

-from jax.experimental.array import Array as ArrayImpl
+from jax.experimental.array import ArrayImpl

from absl.testing import absltest
import numpy as np

@@ -1167,7 +1167,7 @@ class XMapArrayTest(XMapTestCase):
axis_resources={"a": "x", "b": "y"})

out = f(input_array)
-self.assertIsInstance(out, array.Array)
+self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
self.assertDictEqual(out.sharding.mesh.shape, {'x': 4, 'y': 2})
@@ -1191,14 +1191,14 @@ class XMapArrayTest(XMapTestCase):
expected_matrix_mul = np.diagonal(input_data @ input_data.T)
out1, out2 = f(a1, a2)

-self.assertIsInstance(out1, array.Array)
+self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8,))
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
self.assertDictEqual(out1.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])

-self.assertIsInstance(out2, array.Array)
+self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
self.assertDictEqual(out2.sharding.mesh.shape, {'x': 4, 'y': 2})