Rename the concrete class Array to ArrayImpl

PiperOrigin-RevId: 477017236
This commit is contained in:
Yash Katariya 2022-09-26 16:17:26 -07:00 committed by jax authors
parent 71bcabe499
commit cbf34cb609
22 changed files with 140 additions and 163 deletions

View File

@ -497,7 +497,7 @@ def _cpp_jit_clear_cache(self):
def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
use_fastpath = (
xc._version >= 92 and
xc._version >= 96 and
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
@ -506,7 +506,7 @@ def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
not execute.ordered_effects and
not execute.has_unordered_effects and
not execute.has_host_callbacks and
all(isinstance(x, xc.Array) for x in out_flat) and
all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes
# TODO(chky): Check sharding is SingleDeviceSharding
@ -2215,7 +2215,7 @@ def _cpp_pmap(
# sufficient, but we use the more general `DeviceArray`.
all(
isinstance(x, device_array.DeviceArray) or
xc._version >= 94 and isinstance(x, xc.Array) for x in out_flat))
xc._version >= 96 and isinstance(x, xc.ArrayImpl) for x in out_flat))
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
@ -2871,7 +2871,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
return array.Array(
return array.ArrayImpl(
stacked_aval,
sharding.PmapSharding(np.array(devices), sharding_spec),
buffers, committed=True, _skip_checks=True)
@ -2926,7 +2926,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(aval)
return array.Array(
return array.ArrayImpl(
aval, sharding.PmapSharding(np.array(devices), sharding_spec),
[buf, *rest_bufs], committed=True, _skip_checks=True)
else:

View File

@ -57,8 +57,6 @@ import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath
if TYPE_CHECKING:
from jax.experimental.array import Array
FLAGS = flags.FLAGS
@ -750,10 +748,10 @@ class SimpleResultHandler:
def maybe_create_array_from_da(buf, aval, device):
if config.jax_array:
from jax.experimental.array import Array
from jax.experimental.array import ArrayImpl
from jax.experimental.sharding import SingleDeviceSharding
return Array(aval, SingleDeviceSharding(buf.device()), [buf],
committed=(device is not None), _skip_checks=True)
return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
committed=(device is not None), _skip_checks=True)
else:
return device_array.make_device_array(aval, device, buf)
@ -1222,7 +1220,7 @@ def _copy_device_array_to_device(
return device_array.make_device_array(x.aval, device, moved_buf)
def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
"""Copies `Array`s with SingleDeviceSharding to a different device."""
from jax.experimental import array, sharding
@ -1246,7 +1244,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
return array.Array(
return array.ArrayImpl(
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
committed=(device is not None))
@ -1257,7 +1255,7 @@ def _device_put_impl(x, device: Optional[Device] = None):
if device_array.type_is_device_array(x):
return _copy_device_array_to_device(x, device)
if type(x) is array.Array and isinstance(x.sharding, sharding.SingleDeviceSharding):
if type(x) is array.ArrayImpl and isinstance(x.sharding, sharding.SingleDeviceSharding):
return _copy_array_to_device(x, device)
try:

View File

@ -40,10 +40,10 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
owns.
"""
from jax.experimental import array
if not isinstance(x, (device_array.DeviceArray, array.Array)):
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
.format(type(x)))
if isinstance(x, array.Array):
if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
else:

View File

@ -86,7 +86,7 @@ zip, unsafe_zip = safe_zip, zip
def _is_array_or_tracer(operand: Any) -> bool:
if config.jax_array:
from jax.experimental import array # pylint: disable=g-import-not-at-top
return isinstance(operand, (core.Tracer, array.Array))
return isinstance(operand, (core.Tracer, array.ArrayImpl))
else:
return isinstance(operand, (core.Tracer, device_array.DeviceArray))
@ -944,8 +944,6 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array:
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
operator.
"""
from jax.experimental import array
permutation = tuple(operator.index(d) for d in permutation)
if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
return type_cast(Array, operand)

View File

@ -78,7 +78,7 @@ from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
from jax.experimental.array import Array
from jax.experimental.array import ArrayImpl
newaxis = None
@ -1876,7 +1876,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# We can't use the ndarray class because we need to handle internal buffers
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer, Array)
ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
@ -4772,7 +4772,7 @@ _NOT_IMPLEMENTED = ['argpartition']
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, Array)
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
def __array_module__(self, types):
@ -4811,7 +4811,7 @@ def _multi_slice(arr,
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(device_array.DeviceArray, "_unstack", _unstack)
setattr(Array, '_unstack', _unstack)
setattr(ArrayImpl, '_unstack', _unstack)
def _chunk_iter(x, size):
if size > x.shape[0]:
@ -4823,7 +4823,7 @@ def _chunk_iter(x, size):
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter)
setattr(Array, '_chunk_iter', _chunk_iter)
setattr(ArrayImpl, '_chunk_iter', _chunk_iter)
# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
@ -5150,7 +5150,7 @@ def _set_device_array_base_attributes(device_array, include=None, exclude=None):
maybe_setattr("clip", _clip)
_set_device_array_base_attributes(device_array.DeviceArray)
_set_device_array_base_attributes(Array, exclude={'__getitem__'})
_set_device_array_base_attributes(ArrayImpl, exclude={'__getitem__'})
def _set_device_array_attributes(device_array):
@ -5167,4 +5167,4 @@ for t in device_array.device_array_types:
_set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pxla.pmap_lib.ShardedDeviceArray)
_set_device_array_attributes(Array)
_set_device_array_attributes(ArrayImpl)

View File

@ -18,6 +18,7 @@ import operator as op
import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
import jax
from jax import core
from jax._src import abstract_arrays
from jax._src import ad_util
@ -55,7 +56,7 @@ class Shard:
"""
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: Optional[Array] = None):
data: Optional[ArrayImpl] = None):
self.device = device
self._sharding = sharding
self._global_shape = global_shape
@ -96,12 +97,12 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
def _single_device_array_from_buf(buf, committed):
db = pxla._set_aval(buf)
return Array(db.aval, SingleDeviceSharding(db.device()), [db],
committed=committed, _skip_checks=True)
return ArrayImpl(db.aval, SingleDeviceSharding(db.device()), [db],
committed=committed, _skip_checks=True)
@pxla.use_cpp_class(xc.Array if xc._version >= 92 else None)
class Array(basearray.Array):
@pxla.use_cpp_class(xc.ArrayImpl if xc._version >= 96 else None)
class ArrayImpl(basearray.Array):
"""Experimental unified Array type.
This Python implementation will eventually be replaced by a C++ implementation.
@ -117,7 +118,7 @@ class Array(basearray.Array):
@pxla.use_cpp_method
def __init__(self, aval: core.ShapedArray, sharding: Sharding,
arrays: Union[Sequence[DeviceArray], Sequence[Array]],
arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
committed: bool, _skip_checks: bool = False):
# NOTE: the actual implementation of the constructor is moved to C++.
@ -271,8 +272,8 @@ class Array(basearray.Array):
if buf_idx is not None:
buf = self._arrays[buf_idx]
aval = core.ShapedArray(buf.shape, self.dtype)
return Array(aval, SingleDeviceSharding(buf.device()), [buf],
committed=False, _skip_checks=True)
return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
committed=False, _skip_checks=True)
return lax_numpy._rewriting_take(self, idx)
else:
# TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or
@ -385,7 +386,7 @@ class Array(basearray.Array):
return [_single_device_array_from_buf(a, self._committed)
for a in self._arrays]
def addressable_data(self, index: int) -> Array:
def addressable_data(self, index: int) -> ArrayImpl:
self._check_if_deleted()
return _single_device_array_from_buf(self._arrays[index], self._committed)
@ -474,12 +475,12 @@ class Array(basearray.Array):
return cast(np.ndarray, self._npy_value)
# explicitly set to be unhashable. Same as what device_array.py does.
setattr(Array, "__hash__", None)
setattr(Array, "__array_priority__", 100)
setattr(ArrayImpl, "__hash__", None)
setattr(ArrayImpl, "__array_priority__", 100)
def make_array_from_callback(shape: Shape, sharding: Sharding,
data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
def make_array_from_callback(
shape: Shape, sharding: Sharding,
data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
device_to_index_map = sharding.devices_indices_map(shape)
# Use addressable_devices here instead of `_addressable_device_assignment`
# because `_addressable_device_assignment` is only available on
@ -490,32 +491,32 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
for device in sharding.addressable_devices
]
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
return Array(aval, sharding, arrays, committed=True)
return ArrayImpl(aval, sharding, arrays, committed=True)
def make_array_from_single_device_arrays(shape: Shape, sharding: Sharding,
arrays: Sequence[Array]) -> Array:
def make_array_from_single_device_arrays(
shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl:
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
return Array(aval, sharding, arrays, committed=True)
return ArrayImpl(aval, sharding, arrays, committed=True)
core.pytype_aval_mappings[Array] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[Array] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[Array] = pxla.identity
api_util._shaped_abstractify_handlers[Array] = op.attrgetter('aval')
ad_util.jaxval_adders[Array] = lax_internal.add
ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
if xc._version >= 92:
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
ad_util.jaxval_adders[ArrayImpl] = lax_internal.add
ad_util.jaxval_zeros_likers[ArrayImpl] = lax_internal.zeros_like_array
if xc._version >= 96:
# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(Array)
basearray.Array.register(ArrayImpl)
def _array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(val._value,
canonicalize_types=canonicalize_types)
mlir.register_constant_handler(Array, _array_mlir_constant_handler)
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
def _device_put_array(x, device: Optional[Device]):
@ -529,7 +530,7 @@ def _device_put_array(x, device: Optional[Device]):
# Round trip via host if x is sharded. SDA also does a round trip via host.
return dispatch._device_put_array(x._value, device)
dispatch.device_put_handlers[Array] = _device_put_array
dispatch.device_put_handlers[ArrayImpl] = _device_put_array
def _array_pmap_shard_arg(x, devices, indices, mode):
@ -586,7 +587,7 @@ def _array_shard_arg(x, devices, indices, mode):
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
return _array_rest_shard_arg(x, devices, indices, mode)
pxla.shard_arg_handlers[Array] = _array_shard_arg
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed,
@ -596,8 +597,8 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
return lambda bufs: Array(global_aval, out_sharding, bufs,
committed=committed, _skip_checks=True)
return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
committed=committed, _skip_checks=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
@ -608,7 +609,7 @@ def _array_local_result_handler(aval, sharding, indices):
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
return lambda bufs: Array(aval, sharding, bufs, committed=True,
_skip_checks=True)
return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
_skip_checks=True)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler

View File

@ -55,8 +55,8 @@ async def create_async_array_from_callback(
dbs = [jax.device_put(array, device)
for array, device in zip(local_arrays, addressable_da)]
aval = jax.ShapedArray(global_shape, dbs[0].dtype)
return array.Array(aval, inp_sharding, dbs, committed=True)
return array.make_array_from_single_device_arrays(
global_shape, inp_sharding, dbs)
async def create_async_gda_from_callback(
@ -86,7 +86,7 @@ def _get_metadata(arr):
dtype = 'bfloat16'
else:
dtype = np.dtype(arr.dtype).str
if isinstance(arr, array.Array):
if isinstance(arr, array.ArrayImpl):
local_shape = arr._arrays[0].shape
else:
local_shape = arr.local_data(0).shape
@ -150,7 +150,7 @@ class _LimitInFlightBytes:
async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
if (isinstance(arr_inp, array.Array) and jax.process_count() > 1 and
if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
arr_inp.is_fully_addressable()):
raise ValueError('Passing fully addressable Arrays to a multi-host '
'serialization is not allowed.')
@ -187,7 +187,7 @@ async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
else:
await write_future.commit
if isinstance(arr_inp, array.Array):
if isinstance(arr_inp, array.ArrayImpl):
local_shards = arr_inp.addressable_shards
else:
local_shards = arr_inp.local_shards

View File

@ -134,7 +134,7 @@ class CheckpointTest(jtu.JaxTestCase):
[pspec, P('x'), P(None)],
tspecs)
self.assertIsInstance(m1, array.Array)
self.assertIsInstance(m1, array.ArrayImpl)
self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
np.array([[0], [2]]))
self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
@ -142,7 +142,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
self.assertIsInstance(m2, array.Array)
self.assertIsInstance(m2, array.ArrayImpl)
self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
np.array([[16, 17], [18, 19]]))
self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
@ -150,7 +150,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)
self.assertIsInstance(m3, array.Array)
self.assertIsInstance(m3, array.ArrayImpl)
for i, s in enumerate(m3.addressable_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)

View File

@ -39,7 +39,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax.experimental.array import Array
from jax.experimental.array import ArrayImpl
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.sharding import MeshPspecSharding
from jax.interpreters import mlir
@ -1826,7 +1826,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, (GlobalDeviceArray, Array)):
if isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
if arr_flavor == 'Array' and not isinstance(arg.sharding, MeshPspecSharding):
continue

View File

@ -169,7 +169,7 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.Array) for x in out_flat)
all(isinstance(x, xc.ArrayImpl) for x in out_flat)
)
if use_fastpath:
@ -428,7 +428,7 @@ def pjit(fun: Callable,
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)
if FLAGS.experimental_cpp_pjit and xc._version >= 95:
if FLAGS.experimental_cpp_pjit and xc._version >= 96:
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
else:
wrapped = _python_pjit(fun, infer_params)

View File

@ -873,10 +873,10 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
from jax.experimental.array import Array
from jax.experimental.array import ArrayImpl
candidates = defaultdict(list)
if isinstance(x, Array):
if isinstance(x, ArrayImpl):
bufs = x._arrays
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
@ -1002,7 +1002,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
donate_argnums_ = donate_argnums
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.ArrayImpl)):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(
@ -3349,7 +3349,7 @@ def _out_shardings_for_trivial(
device_assignment, sharding._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.Array):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
map(shardings.setdefault, jaxpr.invars, in_shardings)
return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
@ -3374,7 +3374,7 @@ def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.array import Array
from jax.experimental.array import ArrayImpl
@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
@ -3387,7 +3387,7 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
f"xla sharding: {in_xla_sharding}")
for arg, xs in safe_zip(args, in_xla_shardings):
if not isinstance(arg, (GlobalDeviceArray, Array)):
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
continue
if isinstance(arg, GlobalDeviceArray):
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,

View File

@ -79,7 +79,7 @@ numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
def _check_instance(self, x):
if config.jax_array:
self.assertIsInstance(x, array.Array)
self.assertIsInstance(x, array.ArrayImpl)
else:
self.assertIsInstance(x, device_array.DeviceArray)

View File

@ -191,7 +191,7 @@ class JaxArrayTest(jtu.JaxTestCase):
def test_jnp_array(self):
arr = jnp.array([1, 2, 3])
self.assertIsInstance(arr, array.Array)
self.assertIsInstance(arr, array.ArrayImpl)
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
self.assertEqual(arr._committed, False)
@ -199,13 +199,13 @@ class JaxArrayTest(jtu.JaxTestCase):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = jax.jit(lambda x, y: x + y)(a, b)
self.assertIsInstance(arr, array.Array)
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
def test_jnp_array_jnp_add(self):
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
self.assertIsInstance(arr, array.Array)
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
@ -213,7 +213,7 @@ class JaxArrayTest(jtu.JaxTestCase):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = a + b
self.assertIsInstance(arr, array.Array)
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
@ -279,13 +279,13 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 4'):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
with self.assertRaisesRegex(
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 16'):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
@ -299,7 +299,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Some per-device arrays are placed on devices {2, 3}, which are "
"not used in the specified sharding"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
@ -334,7 +334,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
@ -347,7 +347,7 @@ class JaxArrayTest(jtu.JaxTestCase):
sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):
self.assertIsInstance(i, array.Array)
self.assertIsInstance(i, array.ArrayImpl)
self.assertArraysAllClose(i, j)
def test_array_iter_pmap_sharding_last_dim_sharded(self):
@ -433,10 +433,10 @@ class JaxArrayTest(jtu.JaxTestCase):
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
for a in arr.device_buffers:
self.assertIsInstance(a, array.Array)
self.assertIsInstance(a, array.ArrayImpl)
x = jnp.array([1, 2, 3])
self.assertIsInstance(x.device_buffer, array.Array)
self.assertIsInstance(x.device_buffer, array.ArrayImpl)
@jtu.with_config(jax_array=True)
@ -539,8 +539,8 @@ class ShardingTest(jtu.JaxTestCase):
def test_is_subclass(self):
# array version of api_test.py::APITest::test_is_subclass
self.assertTrue(issubclass(array.Array, jnp.ndarray))
self.assertFalse(issubclass(array.Array, np.ndarray))
self.assertTrue(issubclass(array.ArrayImpl, jnp.ndarray))
self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))
if __name__ == '__main__':

View File

@ -26,7 +26,6 @@ import numpy as np
import jax
from jax._src import dtypes
from jax import numpy as jnp
from jax.experimental import array
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
@ -76,13 +75,6 @@ def identity(x):
return x
def _check_instance(self, x):
if config.jax_array:
self.assertIsInstance(x, array.Array)
else:
self.assertIsInstance(x, jnp.DeviceArray)
class DtypesTest(jtu.JaxTestCase):
def test_canonicalize_type(self):
@ -231,7 +223,7 @@ class DtypesTest(jtu.JaxTestCase):
def testScalarInstantiation(self, scalar_type):
a = scalar_type(1)
self.assertEqual(a.dtype, jnp.dtype(scalar_type))
_check_instance(self, a)
self.assertIsInstance(a, jax.Array)
self.assertEqual(0, jnp.ndim(a))
self.assertIsInstance(np.dtype(scalar_type).type(1), scalar_type)

View File

@ -34,7 +34,6 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.experimental import array
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -2464,10 +2463,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]
if config.jax_array:
self.assertIsInstance(ext_res, array.Array)
else:
self.assertIsInstance(ext_res, jnp.DeviceArray)
self.assertIsInstance(ext_res, jax.Array)
def test_scan_vmap_collectives(self):
def scan_f(state, x):

View File

@ -2783,7 +2783,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
out = jnp.concatenate([np_input])
if config.jax_array:
self.assertIs(type(out), array.Array)
self.assertIs(type(out), array.ArrayImpl)
else:
self.assertTrue(device_array.type_is_device_array(out))
@ -4154,8 +4154,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testArrayOutputsDeviceArrays(self):
if config.jax_array:
assert type(jnp.array([])) is array.Array
assert type(jnp.array(np.array([]))) is array.Array
assert type(jnp.array([])) is array.ArrayImpl
assert type(jnp.array(np.array([]))) is array.ArrayImpl
else:
assert device_array.type_is_device_array(jnp.array([]))
assert device_array.type_is_device_array(jnp.array(np.array([])))
@ -4164,7 +4164,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def __array__(self, dtype=None):
return np.array([], dtype=dtype)
if config.jax_array:
assert type(jnp.array(NDArrayLike())) is array.Array
assert type(jnp.array(NDArrayLike())) is array.ArrayImpl
else:
assert device_array.type_is_device_array(jnp.array(NDArrayLike()))

View File

@ -1452,8 +1452,7 @@ class LaxTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(operand)
expected_type = array.Array if config.jax_array else jnp.DeviceArray
self.assertIsInstance(result, expected_type)
self.assertIsInstance(result, jax.Array)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}".format(
@ -2915,10 +2914,7 @@ class LazyConstantTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(input_type(value))
if config.jax_array:
assert isinstance(result, array.Array)
else:
assert isinstance(result, jnp.DeviceArray)
assert isinstance(result, jax.Array)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype_in={}_dtype_out={}".format(
@ -3212,13 +3208,13 @@ class FooArray:
ndim = property(lambda self: self.data.ndim - 1)
def device_put_foo_array(x: FooArray, device):
if isinstance(x.data, array.Array):
if isinstance(x.data, array.ArrayImpl):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)
def shard_foo_array_handler(x, devices, indices, mode):
device, = devices
if isinstance(x.data, array.Array):
if isinstance(x.data, array.ArrayImpl):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)

View File

@ -21,7 +21,6 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import array
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
@ -198,10 +197,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
devices = self.get_devices()
def f(): return lax.add(3., 4.)
if config.jax_array:
self.assertIsInstance(f(), array.Array)
else:
self.assertIsInstance(f(), jnp.DeviceArray)
self.assertIsInstance(f(), jax.Array)
self.assert_uncommitted_to_device(f(), devices[0])
self.assert_uncommitted_to_device(jax.jit(f)(), devices[0])
self.assert_committed_to_device(jax.jit(f, device=devices[1])(),

View File

@ -113,7 +113,7 @@ def simulated_cached_fun(s):
def _check_instance(self, x):
if config.jax_array:
self.assertIsInstance(x, array.Array)
self.assertIsInstance(x, array.ArrayImpl)
else:
self.assertIsInstance(x, pxla.ShardedDeviceArray)
@ -404,7 +404,7 @@ class PJitTest(jtu.BufferDonationTestCase):
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array)
self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
@ -433,7 +433,7 @@ class PJitTest(jtu.BufferDonationTestCase):
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.Array)
self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
@ -1539,7 +1539,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings]
out = compiled(*inputs)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data)
@ -1646,7 +1646,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
for ip in compiled.input_shardings]
out1, out2 = compiled(*inputs)
for o in [out1, out2]:
self.assertIsInstance(o, array.Array)
self.assertIsInstance(o, array.ArrayImpl)
self.assertArraysEqual(o._value, input_data)
@unittest.skip('The error is not raised yet. Enable this back once we raise '
@ -1696,7 +1696,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_matrix_mul = input_data @ input_data.T
out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
@ -1723,7 +1723,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_matrix_mul = input_data @ input_data.T
out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
@ -1744,7 +1744,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Since no in_axis_resources is provided, pjit will assume that
# the numpy input is fully replicated over the mesh.
out = f(input_data)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
@ -1763,7 +1763,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out_axis_resources=MeshPspecSharding(
global_mesh, P('x', 'y')))
out = f(input_data)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
@ -1773,7 +1773,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_unspecified_out_axis_resources(self):
def _checks(out, input_data):
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertIsInstance(out.sharding, OpShardingSharding)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
@ -1826,25 +1826,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
out_tree = f((a1, (a2, (a3, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
self.assertIsInstance(out1, array.Array)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertIsInstance(out2, array.Array)
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertIsInstance(out3, array.Array)
self.assertIsInstance(out3, array.ArrayImpl)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertIsInstance(out4, array.Array)
self.assertIsInstance(out4, array.ArrayImpl)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
for s in out4.addressable_shards:
@ -1879,7 +1879,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = pjit(
lambda x: x,
in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
def test_no_input_output(self):
with jax_array(True):
@ -1918,7 +1918,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))
compiled = f.lower(aval, aval).compile()
out = compiled(a1, a1)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data @ input_data.T)
with self.assertRaisesRegex(
@ -2045,11 +2045,11 @@ class ArrayPjitTest(jtu.JaxTestCase):
def add(x, y):
return x + y
out = add(a, b)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, a + b)
out2 = add(out, out)
self.assertIsInstance(out2, array.Array)
self.assertIsInstance(out2, array.ArrayImpl)
self.assertArraysEqual(out2, 2 * (a + b))
@jax_array(True)
@ -2061,7 +2061,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x @ x.T
out = mul(a)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, a @ a.T)
@jax_array(True)
@ -2107,7 +2107,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
a = jnp.array(16, dtype=jnp.float32)
f = lambda x: x
out = jax.grad(pjit(f))(a)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, jax.grad(f)(a))
@jax_array(True)
@ -2137,7 +2137,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# is correct.
bufs = [jax.device_put(inp_data[s.device_indices(d, shape)], d)
for d in jax.local_devices()]
arr = array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
arr = array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
f = pjit(lambda x: x, out_axis_resources=s)
out = f(arr)
@ -2225,7 +2225,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
arr = array.Array(
arr = array.ArrayImpl(
jax.ShapedArray(shape, np.int32), MeshPspecSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
with self.assertRaisesRegex(

View File

@ -556,7 +556,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f_ans = f(x, y)
self.assertAllClose(f_ans, f_expected)
if config.jax_array:
self.assertIsInstance(f_ans, array.Array)
self.assertIsInstance(f_ans, array.ArrayImpl)
sharding_spec = f_ans.sharding.sharding_spec
else:
self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
@ -571,7 +571,7 @@ class PythonPmapTest(jtu.JaxTestCase):
g_ans = g(x, y)
self.assertAllClose(g_ans, g_expected)
if config.jax_array:
self.assertIsInstance(g_ans, array.Array)
self.assertIsInstance(g_ans, array.ArrayImpl)
sharding_spec = g_ans.sharding.sharding_spec
else:
self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
@ -722,7 +722,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
if config.jax_array:
self.assertIsInstance(y, array.Array)
self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
@ -730,7 +730,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(y, 2 * x, check_dtypes=False)
z = f(y)
if config.jax_array:
self.assertIsInstance(z, array.Array)
self.assertIsInstance(z, array.ArrayImpl)
else:
self.assertIsInstance(z, pxla.ShardedDeviceArray)
self.assertIsInstance(z, device_array.DeviceArray)
@ -740,7 +740,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# test that we can pass in a regular DeviceArray
y = f(device_put(x))
if config.jax_array:
self.assertIsInstance(y, array.Array)
self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
@ -1606,7 +1606,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
if config.jax_array:
self.assertIsInstance(y, array.Array)
self.assertIsInstance(y, array.ArrayImpl)
else:
self.assertIsInstance(y, pxla.ShardedDeviceArray)
@ -2585,7 +2585,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
self.assertIsNone(sharded_x._npy_value)
if config.jax_array:
arr_type = array.Array
arr_type = array.ArrayImpl
else:
arr_type = device_array.DeviceArray
@ -2595,7 +2595,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_sharded(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@ -2611,7 +2611,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_sharded_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@ -2632,7 +2632,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_replicated(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@ -2648,7 +2648,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
('array', True, array.ArrayImpl, '_arrays')
)
def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
@ -2900,7 +2900,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
expected = input_data * input_data
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected[s.index])
self.assertArraysEqual(out, expected)
@ -2918,8 +2918,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
with jax_config.jax_array(True):
out1, out2 = f(input_array, input_array)
self.assertIsInstance(out1, array.Array)
self.assertIsInstance(out2, array.Array)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertIsInstance(out2, array.ArrayImpl)
for s1, s2 in safe_zip(out1.addressable_shards, out2.addressable_shards):
self.assertArraysEqual(s1.data._arrays[0], input_data[s1.index])
self.assertArraysEqual(s2.data._arrays[0], input_data[s2.index])
@ -2942,8 +2942,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
with jax_config.jax_array(True):
out1, out2 = f(a1, a2)
self.assertIsInstance(out1, array.Array)
self.assertIsInstance(out2, array.Array)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out1.shape, (2,))
self.assertEqual(out2.shape, (dc, dc, 2))
for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):

View File

@ -27,7 +27,7 @@ from jax._src import typing
from jax import lax
import jax.numpy as jnp
from jax.experimental.array import Array as ArrayImpl
from jax.experimental.array import ArrayImpl
from absl.testing import absltest
import numpy as np

View File

@ -1167,7 +1167,7 @@ class XMapArrayTest(XMapTestCase):
axis_resources={"a": "x", "b": "y"})
out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
self.assertDictEqual(out.sharding.mesh.shape, {'x': 4, 'y': 2})
@ -1191,14 +1191,14 @@ class XMapArrayTest(XMapTestCase):
expected_matrix_mul = np.diagonal(input_data @ input_data.T)
out1, out2 = f(a1, a2)
self.assertIsInstance(out1, array.Array)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8,))
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
self.assertDictEqual(out1.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertIsInstance(out2, array.Array)
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
self.assertDictEqual(out2.sharding.mesh.shape, {'x': 4, 'y': 2})