mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant. PiperOrigin-RevId: 469984029
This commit is contained in:
parent
fd3a72dd1f
commit
5527966b27
@ -22,6 +22,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Deprecations:
|
||||
* The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile`
|
||||
({jax-issue}`#11944`).
|
||||
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
|
||||
|
||||
## jax 0.3.16
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
|
||||
|
@ -2066,7 +2066,7 @@
|
||||
"\n",
|
||||
"def handle_result(aval: ShapedArray, buf):\n",
|
||||
" del aval # Unused for now\n",
|
||||
" return buf.to_py()\n",
|
||||
" return np.asarray(buf)\n",
|
||||
"\n",
|
||||
"xla_translations = {}"
|
||||
]
|
||||
@ -2370,9 +2370,9 @@
|
||||
" shape = property(lambda self: self.aval.shape)\n",
|
||||
" ndim = property(lambda self: self.aval.ndim)\n",
|
||||
"\n",
|
||||
" def __array__(self): return self.buf.to_py()\n",
|
||||
" def __repr__(self): return repr(self.buf.to_py())\n",
|
||||
" def __str__(self): return str(self.buf.to_py())\n",
|
||||
" def __array__(self): return np.asarray(self.buf)\n",
|
||||
" def __repr__(self): return repr(np.asarray(self.buf))\n",
|
||||
" def __str__(self): return str(np.asarray(self.buf))\n",
|
||||
"\n",
|
||||
" _neg = staticmethod(neg)\n",
|
||||
" _add = staticmethod(add)\n",
|
||||
|
@ -1626,7 +1626,7 @@ input_handlers = {ty: default_input_handler for ty in
|
||||
|
||||
def handle_result(aval: ShapedArray, buf):
|
||||
del aval # Unused for now
|
||||
return buf.to_py()
|
||||
return np.asarray(buf)
|
||||
|
||||
xla_translations = {}
|
||||
```
|
||||
@ -1842,9 +1842,9 @@ class DeviceArray:
|
||||
shape = property(lambda self: self.aval.shape)
|
||||
ndim = property(lambda self: self.aval.ndim)
|
||||
|
||||
def __array__(self): return self.buf.to_py()
|
||||
def __repr__(self): return repr(self.buf.to_py())
|
||||
def __str__(self): return str(self.buf.to_py())
|
||||
def __array__(self): return np.asarray(self.buf)
|
||||
def __repr__(self): return repr(np.asarray(self.buf))
|
||||
def __str__(self): return str(np.asarray(self.buf))
|
||||
|
||||
_neg = staticmethod(neg)
|
||||
_add = staticmethod(add)
|
||||
|
@ -1620,7 +1620,7 @@ input_handlers = {ty: default_input_handler for ty in
|
||||
|
||||
def handle_result(aval: ShapedArray, buf):
|
||||
del aval # Unused for now
|
||||
return buf.to_py()
|
||||
return np.asarray(buf)
|
||||
|
||||
xla_translations = {}
|
||||
|
||||
@ -1833,9 +1833,9 @@ class DeviceArray:
|
||||
shape = property(lambda self: self.aval.shape)
|
||||
ndim = property(lambda self: self.aval.ndim)
|
||||
|
||||
def __array__(self): return self.buf.to_py()
|
||||
def __repr__(self): return repr(self.buf.to_py())
|
||||
def __str__(self): return str(self.buf.to_py())
|
||||
def __array__(self): return np.asarray(self.buf)
|
||||
def __repr__(self): return repr(np.asarray(self.buf))
|
||||
def __str__(self): return str(np.asarray(self.buf))
|
||||
|
||||
_neg = staticmethod(neg)
|
||||
_add = staticmethod(add)
|
||||
|
@ -18,6 +18,7 @@ from functools import partial, partialmethod
|
||||
import operator
|
||||
from typing import (Any, List, Optional, Union)
|
||||
import weakref
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -146,7 +147,7 @@ class _DeviceArray(DeviceArray): # type: ignore
|
||||
def _value(self):
|
||||
self._check_if_deleted()
|
||||
if self._npy_value is None:
|
||||
self._npy_value = self.device_buffer.to_py() # pytype: disable=attribute-error # bind-properties
|
||||
self._npy_value = np.asarray(self.device_buffer) # pytype: disable=attribute-error # bind-properties
|
||||
self._npy_value.flags.writeable = False
|
||||
return self._npy_value
|
||||
|
||||
@ -266,6 +267,15 @@ for device_array in [DeviceArray]:
|
||||
|
||||
setattr(device_array, "__array__", __array__)
|
||||
|
||||
# TODO(phawkins): delete this code path after the deprecation for .to_py()
|
||||
# expires in Nov 2022.
|
||||
def to_py(self):
|
||||
warnings.warn("The .to_py() method on JAX arrays is deprecated. Use "
|
||||
"np.asarray(...) instead.", category=FutureWarning)
|
||||
return np.asarray(self._value)
|
||||
|
||||
setattr(device_array, "to_py", to_py)
|
||||
|
||||
def __dlpack__(self):
|
||||
from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
|
||||
return to_dlpack(self)
|
||||
|
@ -779,9 +779,9 @@ def check_special(name, bufs):
|
||||
def _check_special(name, xla_shape, buf):
|
||||
assert not xla_shape.is_tuple()
|
||||
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
|
||||
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
|
||||
if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
|
||||
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
|
||||
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
|
||||
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
|
||||
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
||||
|
||||
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
|
||||
@ -1155,7 +1155,7 @@ def _copy_device_array_to_device(
|
||||
else:
|
||||
# buffers from different XLA backends are passed through the host.
|
||||
backend = xb.get_device_backend(device)
|
||||
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
|
||||
moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
|
||||
return device_array.make_device_array(x.aval, device, moved_buf)
|
||||
|
||||
|
||||
@ -1182,7 +1182,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
|
||||
else:
|
||||
# buffers from different XLA backends are passed through the host.
|
||||
backend = xb.get_device_backend(device)
|
||||
moved_buf = backend.buffer_from_pyval(buf.to_py(), device)
|
||||
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
|
||||
return array.Array(
|
||||
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
|
||||
committed=(device is not None))
|
||||
|
@ -88,7 +88,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
|
||||
def copy_to_device(self, device):
|
||||
return self
|
||||
|
||||
def to_py(self) -> np.ndarray:
|
||||
def __array__(self, dtype=None, context=None):
|
||||
return np.asarray(self._buffer)
|
||||
|
||||
def to_iree(self):
|
||||
@ -104,8 +104,11 @@ class IreeBuffer(xla_client.DeviceArrayBase):
|
||||
return self # no async
|
||||
|
||||
# overrides repr on base class which expects _value and aval attributes
|
||||
def __repr__(self): return f'IreeBuffer({self.to_py()})'
|
||||
_value = property(to_py)
|
||||
def __repr__(self): return f'IreeBuffer({np.asarray(self)})'
|
||||
|
||||
@property
|
||||
def _value(self):
|
||||
return np.asarray(self)
|
||||
|
||||
class IreeExecutable:
|
||||
|
||||
|
@ -290,9 +290,6 @@ class Array:
|
||||
self._check_if_deleted()
|
||||
return list(self.sharding.device_set)
|
||||
|
||||
def to_py(self) -> np.ndarray:
|
||||
return self._value
|
||||
|
||||
@pxla.maybe_cached_property
|
||||
def addressable_shards(self) -> Sequence[Shard]:
|
||||
self._check_if_deleted()
|
||||
@ -367,7 +364,7 @@ class Array:
|
||||
|
||||
for s in self.addressable_shards:
|
||||
if not replica_id_exists or s.replica_id == 0:
|
||||
npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr]
|
||||
npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr]
|
||||
self._npy_value = npy_value # type: ignore
|
||||
# https://docs.python.org/3/library/typing.html#typing.cast
|
||||
return cast(np.ndarray, self._npy_value)
|
||||
|
@ -72,16 +72,16 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
[mesh_axes, P('x'), P(None)],
|
||||
tspecs)
|
||||
|
||||
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m1.local_shards[0].data),
|
||||
np.array([[0], [2]]))
|
||||
self.assertArraysEqual(m1.local_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m1.local_shards[1].data),
|
||||
np.array([[1], [3]]))
|
||||
self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
|
||||
self.assertEqual(m1.dtype, np.int32)
|
||||
|
||||
self.assertArraysEqual(m2.local_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m2.local_shards[0].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertArraysEqual(m2.local_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m2.local_shards[1].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
|
||||
self.assertEqual(m2.dtype, np.int32)
|
||||
@ -89,7 +89,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for i, s in enumerate(m3.local_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
@ -132,17 +132,17 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
tspecs)
|
||||
|
||||
self.assertIsInstance(m1, array.Array)
|
||||
self.assertArraysEqual(m1.addressable_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
|
||||
np.array([[0], [2]]))
|
||||
self.assertArraysEqual(m1.addressable_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
|
||||
np.array([[1], [3]]))
|
||||
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
|
||||
self.assertEqual(m1.dtype, np.int32)
|
||||
|
||||
self.assertIsInstance(m2, array.Array)
|
||||
self.assertArraysEqual(m2.addressable_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertArraysEqual(m2.addressable_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
|
||||
self.assertEqual(m2.dtype, np.int32)
|
||||
@ -151,7 +151,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for i, s in enumerate(m3.addressable_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape(self):
|
||||
@ -192,7 +192,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
def test_checkpointing_scalar(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
@ -216,7 +216,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(l.data.to_py(), data.astype(np.float32))
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_spec_has_metadata(self):
|
||||
spec = {
|
||||
|
@ -386,9 +386,12 @@ class GlobalDeviceArray:
|
||||
for s in self.local_shards if s.replica_id == 0]
|
||||
npy_value = np.empty(self.shape, self.dtype)
|
||||
for s in unique_shards:
|
||||
npy_value[s.index] = s.data.to_py()
|
||||
npy_value[s.index] = np.asarray(s.data)
|
||||
return npy_value
|
||||
|
||||
def __array__(self, dtype=None, context=None):
|
||||
return self._value if dtype is None else self._value.astype(dtype)
|
||||
|
||||
def local_data(self, index) -> DeviceArray:
|
||||
return pxla._set_aval(self._device_buffers[index])
|
||||
|
||||
|
@ -103,7 +103,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
|
||||
def _pjit(inp):
|
||||
if isinstance(inp, GlobalDeviceArray):
|
||||
if inp.is_fully_replicated:
|
||||
return inp.local_data(0).to_py()
|
||||
return np.asarray(inp.local_data(0))
|
||||
global_mesh = inp.mesh
|
||||
in_axis_resources = FROM_GDA
|
||||
else:
|
||||
@ -119,7 +119,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
|
||||
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
|
||||
out_axis_resources=None)(inp)
|
||||
return out.local_data(0).to_py()
|
||||
return np.asarray(out.local_data(0))
|
||||
|
||||
with config_internal.parallel_functions_output_gda(True):
|
||||
return jax.tree_util.tree_map(_pjit, in_tree)
|
||||
|
@ -303,7 +303,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
|
||||
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|
||||
|
||||
def _device_array_constant_handler(val, canonicalize_types):
|
||||
return _ndarray_constant_handler(val.device_buffer.to_py(),
|
||||
return _ndarray_constant_handler(np.asarray(val.device_buffer),
|
||||
canonicalize_types)
|
||||
for t in device_array.device_array_types:
|
||||
register_constant_handler(t, _device_array_constant_handler)
|
||||
|
@ -699,7 +699,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
|
||||
precomputed for efficiency. A list the same length as
|
||||
`device_buffers`. Each index indicates what portion of the full array is
|
||||
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
|
||||
device_buffers[i].to_py()`.
|
||||
np.asarray(device_buffers[i])`.
|
||||
"""
|
||||
__slots__ = [
|
||||
"aval", "device_buffers", "sharding_spec", "indices",
|
||||
@ -792,7 +792,7 @@ def _sda_value(self):
|
||||
self.copy_to_host_async()
|
||||
npy_value = np.empty(self.aval.shape, self.aval.dtype)
|
||||
for i in self.one_replica_buffer_indices:
|
||||
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
|
||||
npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
|
||||
self._npy_value = npy_value
|
||||
return self._npy_value
|
||||
|
||||
|
@ -194,7 +194,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
for i, s in enumerate(gda.local_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(gda.dtype, np.float32)
|
||||
self.assertEqual(
|
||||
gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
|
||||
@ -249,10 +249,10 @@ class GDATest(jtu.JaxTestCase):
|
||||
gda = GlobalDeviceArray.from_batched_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1]])
|
||||
self.assertArraysEqual(gda.local_data(0).to_py(),
|
||||
self.assertArraysEqual(np.asarray(gda.local_data(0)),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[2, 3]])
|
||||
self.assertArraysEqual(gda.local_data(1).to_py(),
|
||||
self.assertArraysEqual(np.asarray(gda.local_data(1)),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gda_batched_callback_with_devices(self):
|
||||
@ -275,10 +275,10 @@ class GDATest(jtu.JaxTestCase):
|
||||
gda = GlobalDeviceArray.from_batched_callback_with_devices(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(gda.local_data(0).to_py(),
|
||||
self.assertArraysEqual(np.asarray(gda.local_data(0)),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(gda.local_data(1).to_py(),
|
||||
self.assertArraysEqual(np.asarray(gda.local_data(1)),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gda_str_repr(self):
|
||||
|
@ -113,7 +113,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
complex_type = dtypes.canonicalize_dtype(np.complex128)
|
||||
|
||||
# int
|
||||
res = _cpp_device_put(1, device).to_py()
|
||||
res = np.asarray(_cpp_device_put(1, device))
|
||||
self.assertEqual(res, 1)
|
||||
self.assertEqual(res.dtype, int_type)
|
||||
# We also compare to the Python Jax API, to make sure we have the exact
|
||||
@ -122,20 +122,20 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jnp.asarray(1).dtype, res.dtype)
|
||||
|
||||
# float
|
||||
res = _cpp_device_put(1.0, device).to_py()
|
||||
res = np.asarray(_cpp_device_put(1.0, device))
|
||||
self.assertEqual(res, 1.0)
|
||||
self.assertEqual(res.dtype, float_type)
|
||||
self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)
|
||||
|
||||
# bool
|
||||
for bool_value in [True, False]:
|
||||
res = _cpp_device_put(bool_value, device).to_py()
|
||||
res = np.asarray(_cpp_device_put(bool_value, device))
|
||||
self.assertEqual(res, np.asarray(bool_value))
|
||||
self.assertEqual(res.dtype, np.bool_)
|
||||
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
|
||||
|
||||
# Complex
|
||||
res = _cpp_device_put(1 + 1j, device).to_py()
|
||||
res = np.asarray(_cpp_device_put(1 + 1j, device))
|
||||
self.assertEqual(res, 1 + 1j)
|
||||
self.assertEqual(res.dtype, complex_type)
|
||||
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
|
||||
|
@ -135,7 +135,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
|
||||
self.assertLen(actual.device_buffers, 1)
|
||||
self.assertAllClose(
|
||||
actual.device_buffers[0].to_py(), expected, check_dtypes=False)
|
||||
np.asarray(actual.device_buffers[0]), expected, check_dtypes=False)
|
||||
# Repro for a bug on device_buffer aval
|
||||
_ = repr(actual.device_buffers)
|
||||
|
||||
@ -154,7 +154,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
|
||||
self.assertLen(actual.device_buffers, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
|
||||
check_dtypes=False)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
@ -191,7 +191,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
|
||||
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
|
||||
self.assertLen(actual.device_buffers, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py()[:3], expected[:3],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0])[:3], expected[:3],
|
||||
check_dtypes=False)
|
||||
|
||||
def testBasic1DWithMeshContextManager(self):
|
||||
@ -210,7 +210,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
|
||||
self.assertLen(actual.device_buffers, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
|
||||
check_dtypes=False)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@ -232,13 +232,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertLen(actual.device_buffers, 4)
|
||||
|
||||
split0, split1 = np.split(expected, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), split0,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[1].to_py(), split0,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[2].to_py(), split1,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
|
||||
check_dtypes=False)
|
||||
|
||||
def testBasic2DWithMeshContextManager(self):
|
||||
@ -261,13 +261,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertLen(actual.device_buffers, 4)
|
||||
|
||||
split0, split1 = np.split(expected, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), split0,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), split0,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[1].to_py(), split0,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[1]), split0,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[2].to_py(), split1,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[2]), split1,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[3]), split1,
|
||||
check_dtypes=False)
|
||||
|
||||
def testDifferentNestedMesh(self):
|
||||
@ -318,13 +318,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertLen(actual.device_buffers, 4)
|
||||
|
||||
splits = np.split(expected, 4)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
|
||||
check_dtypes=False)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
@ -363,7 +363,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
|
||||
self.assertLen(actual.device_buffers, 2)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
|
||||
check_dtypes=False)
|
||||
|
||||
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
|
||||
@ -390,7 +390,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, array.Array)
|
||||
self.assertLen(actual.addressable_shards, 2)
|
||||
self.assertAllClose(actual._arrays[0].to_py(), expected,
|
||||
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
|
||||
check_dtypes=False)
|
||||
|
||||
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
|
||||
@ -419,7 +419,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, array.Array)
|
||||
self.assertLen(actual.addressable_shards, 2)
|
||||
self.assertAllClose(actual._arrays[0].to_py(), expected,
|
||||
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
|
||||
check_dtypes=False)
|
||||
|
||||
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
|
||||
@ -842,13 +842,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
|
||||
|
||||
splits = np.split(expected, 4)
|
||||
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2],
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3],
|
||||
check_dtypes=False)
|
||||
|
||||
for obj in [lowered, compiled]:
|
||||
|
@ -3016,7 +3016,7 @@ class ShardArgsTest(jtu.JaxTestCase):
|
||||
self.assertEqual(len(bufs), 1)
|
||||
self.assertEqual(len(bufs[0]), nshards)
|
||||
for buf, idx in zip(bufs[0], indices):
|
||||
self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False)
|
||||
self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False)
|
||||
|
||||
|
||||
class ArrayPmapTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user