From 5527966b27b0e6899db802a7a8f1121ccc8f2b11 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 25 Aug 2022 07:27:54 -0700 Subject: [PATCH] [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead. .to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant. PiperOrigin-RevId: 469984029 --- CHANGELOG.md | 1 + docs/autodidax.ipynb | 8 ++-- docs/autodidax.md | 8 ++-- docs/autodidax.py | 8 ++-- jax/_src/device_array.py | 12 ++++- jax/_src/dispatch.py | 8 ++-- jax/_src/iree.py | 9 ++-- jax/experimental/array.py | 5 +- .../gda_serialization/serialization_test.py | 24 +++++----- jax/experimental/global_device_array.py | 5 +- jax/experimental/multihost_utils.py | 4 +- jax/interpreters/mlir.py | 2 +- jax/interpreters/pxla.py | 4 +- tests/global_device_array_test.py | 10 ++-- tests/jax_jit_test.py | 8 ++-- tests/pjit_test.py | 46 +++++++++---------- tests/pmap_test.py | 2 +- 17 files changed, 89 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bf6a3fd1..eee2336e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Deprecations: * The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile` ({jax-issue}`#11944`). + * `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead. ## jax 0.3.16 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 097bbed2c..7a26ce598 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2066,7 +2066,7 @@ "\n", "def handle_result(aval: ShapedArray, buf):\n", " del aval # Unused for now\n", - " return buf.to_py()\n", + " return np.asarray(buf)\n", "\n", "xla_translations = {}" ] @@ -2370,9 +2370,9 @@ " shape = property(lambda self: self.aval.shape)\n", " ndim = property(lambda self: self.aval.ndim)\n", "\n", - " def __array__(self): return self.buf.to_py()\n", - " def __repr__(self): return repr(self.buf.to_py())\n", - " def __str__(self): return str(self.buf.to_py())\n", + " def __array__(self): return np.asarray(self.buf)\n", + " def __repr__(self): return repr(np.asarray(self.buf))\n", + " def __str__(self): return str(np.asarray(self.buf))\n", "\n", " _neg = staticmethod(neg)\n", " _add = staticmethod(add)\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index ff924595b..8fe47592e 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1626,7 +1626,7 @@ input_handlers = {ty: default_input_handler for ty in def handle_result(aval: ShapedArray, buf): del aval # Unused for now - return buf.to_py() + return np.asarray(buf) xla_translations = {} ``` @@ -1842,9 +1842,9 @@ class DeviceArray: shape = property(lambda self: self.aval.shape) ndim = property(lambda self: self.aval.ndim) - def __array__(self): return self.buf.to_py() - def __repr__(self): return repr(self.buf.to_py()) - def __str__(self): return str(self.buf.to_py()) + def __array__(self): return np.asarray(self.buf) + def __repr__(self): return repr(np.asarray(self.buf)) + def __str__(self): return str(np.asarray(self.buf)) _neg = staticmethod(neg) _add = staticmethod(add) diff --git a/docs/autodidax.py b/docs/autodidax.py index 802564dd4..7fa84f82d 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1620,7 +1620,7 @@ input_handlers = {ty: default_input_handler for ty in def handle_result(aval: ShapedArray, buf): del aval # Unused for now - return buf.to_py() + return np.asarray(buf) xla_translations = {} @@ -1833,9 +1833,9 @@ class DeviceArray: shape = property(lambda self: self.aval.shape) ndim = property(lambda self: self.aval.ndim) - def __array__(self): return self.buf.to_py() - def __repr__(self): return repr(self.buf.to_py()) - def __str__(self): return str(self.buf.to_py()) + def __array__(self): return np.asarray(self.buf) + def __repr__(self): return repr(np.asarray(self.buf)) + def __str__(self): return str(np.asarray(self.buf)) _neg = staticmethod(neg) _add = staticmethod(add) diff --git a/jax/_src/device_array.py b/jax/_src/device_array.py index ed31feb47..58982416c 100644 --- a/jax/_src/device_array.py +++ b/jax/_src/device_array.py @@ -18,6 +18,7 @@ from functools import partial, partialmethod import operator from typing import (Any, List, Optional, Union) import weakref +import warnings import numpy as np @@ -146,7 +147,7 @@ class _DeviceArray(DeviceArray): # type: ignore def _value(self): self._check_if_deleted() if self._npy_value is None: - self._npy_value = self.device_buffer.to_py() # pytype: disable=attribute-error # bind-properties + self._npy_value = np.asarray(self.device_buffer) # pytype: disable=attribute-error # bind-properties self._npy_value.flags.writeable = False return self._npy_value @@ -266,6 +267,15 @@ for device_array in [DeviceArray]: setattr(device_array, "__array__", __array__) + # TODO(phawkins): delete this code path after the deprecation for .to_py() + # expires in Nov 2022. + def to_py(self): + warnings.warn("The .to_py() method on JAX arrays is deprecated. Use " + "np.asarray(...) instead.", category=FutureWarning) + return np.asarray(self._value) + + setattr(device_array, "to_py", to_py) + def __dlpack__(self): from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top return to_dlpack(self) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e470b7ad5..a32aa6bf1 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -779,9 +779,9 @@ def check_special(name, bufs): def _check_special(name, xla_shape, buf): assert not xla_shape.is_tuple() if dtypes.issubdtype(xla_shape.element_type(), np.inexact): - if config.jax_debug_nans and np.any(np.isnan(buf.to_py())): + if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))): raise FloatingPointError(f"invalid value (nan) encountered in {name}") - if config.jax_debug_infs and np.any(np.isinf(buf.to_py())): + if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))): raise FloatingPointError(f"invalid value (inf) encountered in {name}") def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect], @@ -1155,7 +1155,7 @@ def _copy_device_array_to_device( else: # buffers from different XLA backends are passed through the host. backend = xb.get_device_backend(device) - moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device) + moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device) return device_array.make_device_array(x.aval, device, moved_buf) @@ -1182,7 +1182,7 @@ def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array: else: # buffers from different XLA backends are passed through the host. backend = xb.get_device_backend(device) - moved_buf = backend.buffer_from_pyval(buf.to_py(), device) + moved_buf = backend.buffer_from_pyval(np.asarray(buf), device) return array.Array( x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf], committed=(device is not None)) diff --git a/jax/_src/iree.py b/jax/_src/iree.py index 288581a9c..22e6d4144 100644 --- a/jax/_src/iree.py +++ b/jax/_src/iree.py @@ -88,7 +88,7 @@ class IreeBuffer(xla_client.DeviceArrayBase): def copy_to_device(self, device): return self - def to_py(self) -> np.ndarray: + def __array__(self, dtype=None, context=None): return np.asarray(self._buffer) def to_iree(self): @@ -104,8 +104,11 @@ class IreeBuffer(xla_client.DeviceArrayBase): return self # no async # overrides repr on base class which expects _value and aval attributes - def __repr__(self): return f'IreeBuffer({self.to_py()})' - _value = property(to_py) + def __repr__(self): return f'IreeBuffer({np.asarray(self)})' + + @property + def _value(self): + return np.asarray(self) class IreeExecutable: diff --git a/jax/experimental/array.py b/jax/experimental/array.py index a165d8695..091a4df5d 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -290,9 +290,6 @@ class Array: self._check_if_deleted() return list(self.sharding.device_set) - def to_py(self) -> np.ndarray: - return self._value - @pxla.maybe_cached_property def addressable_shards(self) -> Sequence[Shard]: self._check_if_deleted() @@ -367,7 +364,7 @@ class Array: for s in self.addressable_shards: if not replica_id_exists or s.replica_id == 0: - npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr] + npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr] self._npy_value = npy_value # type: ignore # https://docs.python.org/3/library/typing.html#typing.cast return cast(np.ndarray, self._npy_value) diff --git a/jax/experimental/gda_serialization/serialization_test.py b/jax/experimental/gda_serialization/serialization_test.py index aed78891d..5973df308 100644 --- a/jax/experimental/gda_serialization/serialization_test.py +++ b/jax/experimental/gda_serialization/serialization_test.py @@ -72,16 +72,16 @@ class CheckpointTest(jtu.JaxTestCase): [mesh_axes, P('x'), P(None)], tspecs) - self.assertArraysEqual(m1.local_shards[0].data.to_py(), + self.assertArraysEqual(np.asarray(m1.local_shards[0].data), np.array([[0], [2]])) - self.assertArraysEqual(m1.local_shards[1].data.to_py(), + self.assertArraysEqual(np.asarray(m1.local_shards[1].data), np.array([[1], [3]])) self.assertEqual(m1.local_shards[0].data.shape, (2, 1)) self.assertEqual(m1.dtype, np.int32) - self.assertArraysEqual(m2.local_shards[0].data.to_py(), + self.assertArraysEqual(np.asarray(m2.local_shards[0].data), np.array([[16, 17], [18, 19]])) - self.assertArraysEqual(m2.local_shards[1].data.to_py(), + self.assertArraysEqual(np.asarray(m2.local_shards[1].data), np.array([[16, 17], [18, 19]])) self.assertEqual(m2.local_shards[0].data.shape, (2, 2)) self.assertEqual(m2.dtype, np.int32) @@ -89,7 +89,7 @@ class CheckpointTest(jtu.JaxTestCase): for i, s in enumerate(m3.local_shards): self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.replica_id, i) - self.assertArraysEqual(s.data.to_py(), np.array([])) + self.assertArraysEqual(np.asarray(s.data), np.array([])) self.assertEqual(m3.dtype, np.float32) @jax_config.jax_array(True) @@ -132,17 +132,17 @@ class CheckpointTest(jtu.JaxTestCase): tspecs) self.assertIsInstance(m1, array.Array) - self.assertArraysEqual(m1.addressable_shards[0].data.to_py(), + self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data), np.array([[0], [2]])) - self.assertArraysEqual(m1.addressable_shards[1].data.to_py(), + self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data), np.array([[1], [3]])) self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1)) self.assertEqual(m1.dtype, np.int32) self.assertIsInstance(m2, array.Array) - self.assertArraysEqual(m2.addressable_shards[0].data.to_py(), + self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data), np.array([[16, 17], [18, 19]])) - self.assertArraysEqual(m2.addressable_shards[1].data.to_py(), + self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data), np.array([[16, 17], [18, 19]])) self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2)) self.assertEqual(m2.dtype, np.int32) @@ -151,7 +151,7 @@ class CheckpointTest(jtu.JaxTestCase): for i, s in enumerate(m3.addressable_shards): self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.replica_id, i) - self.assertArraysEqual(s.data.to_py(), np.array([])) + self.assertArraysEqual(np.asarray(s.data), np.array([])) self.assertEqual(m3.dtype, np.float32) def test_checkpointing_with_bigger_shape(self): @@ -192,7 +192,7 @@ class CheckpointTest(jtu.JaxTestCase): } for l in m1.local_shards: - self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id]) + self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) def test_checkpointing_scalar(self): global_mesh = jtu.create_global_mesh((2,), ('x')) @@ -216,7 +216,7 @@ class CheckpointTest(jtu.JaxTestCase): ) for l in m1.local_shards: - self.assertArraysEqual(l.data.to_py(), data.astype(np.float32)) + self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) def test_spec_has_metadata(self): spec = { diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index d39ef3359..65a0b0bab 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -386,9 +386,12 @@ class GlobalDeviceArray: for s in self.local_shards if s.replica_id == 0] npy_value = np.empty(self.shape, self.dtype) for s in unique_shards: - npy_value[s.index] = s.data.to_py() + npy_value[s.index] = np.asarray(s.data) return npy_value + def __array__(self, dtype=None, context=None): + return self._value if dtype is None else self._value.astype(dtype) + def local_data(self, index) -> DeviceArray: return pxla._set_aval(self._device_buffers[index]) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index f36225590..75665a083 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -103,7 +103,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef: def _pjit(inp): if isinstance(inp, GlobalDeviceArray): if inp.is_fully_replicated: - return inp.local_data(0).to_py() + return np.asarray(inp.local_data(0)) global_mesh = inp.mesh in_axis_resources = FROM_GDA else: @@ -119,7 +119,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef: with maps.Mesh(global_mesh.devices, global_mesh.axis_names): out = pjit(lambda x: x, in_axis_resources=in_axis_resources, out_axis_resources=None)(inp) - return out.local_data(0).to_py() + return np.asarray(out.local_data(0)) with config_internal.parallel_functions_output_gda(True): return jax.tree_util.tree_map(_pjit, in_tree) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index f88741dfa..9d7da999d 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -303,7 +303,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items(): register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) def _device_array_constant_handler(val, canonicalize_types): - return _ndarray_constant_handler(val.device_buffer.to_py(), + return _ndarray_constant_handler(np.asarray(val.device_buffer), canonicalize_types) for t in device_array.device_array_types: register_constant_handler(t, _device_array_constant_handler) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index c07cb87b4..fb411c082 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -699,7 +699,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore precomputed for efficiency. A list the same length as `device_buffers`. Each index indicates what portion of the full array is stored in the corresponding device buffer, i.e. `array[indices[i]] == - device_buffers[i].to_py()`. + np.asarray(device_buffers[i])`. """ __slots__ = [ "aval", "device_buffers", "sharding_spec", "indices", @@ -792,7 +792,7 @@ def _sda_value(self): self.copy_to_host_async() npy_value = np.empty(self.aval.shape, self.aval.dtype) for i in self.one_replica_buffer_indices: - npy_value[self.indices[i]] = self.device_buffers[i].to_py() + npy_value[self.indices[i]] = np.asarray(self.device_buffers[i]) self._npy_value = npy_value return self._npy_value diff --git a/tests/global_device_array_test.py b/tests/global_device_array_test.py index c7540c2b5..d862e23cf 100644 --- a/tests/global_device_array_test.py +++ b/tests/global_device_array_test.py @@ -194,7 +194,7 @@ class GDATest(jtu.JaxTestCase): for i, s in enumerate(gda.local_shards): self.assertEqual(s.index, (slice(None),)) self.assertEqual(s.replica_id, i) - self.assertArraysEqual(s.data.to_py(), np.array([])) + self.assertArraysEqual(np.asarray(s.data), np.array([])) self.assertEqual(gda.dtype, np.float32) self.assertEqual( gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes), @@ -249,10 +249,10 @@ class GDATest(jtu.JaxTestCase): gda = GlobalDeviceArray.from_batched_callback( global_input_shape, global_mesh, mesh_axes, cb) expected_first_shard_value = np.array([[0, 1]]) - self.assertArraysEqual(gda.local_data(0).to_py(), + self.assertArraysEqual(np.asarray(gda.local_data(0)), expected_first_shard_value) expected_second_shard_value = np.array([[2, 3]]) - self.assertArraysEqual(gda.local_data(1).to_py(), + self.assertArraysEqual(np.asarray(gda.local_data(1)), expected_second_shard_value) def test_gda_batched_callback_with_devices(self): @@ -275,10 +275,10 @@ class GDATest(jtu.JaxTestCase): gda = GlobalDeviceArray.from_batched_callback_with_devices( global_input_shape, global_mesh, mesh_axes, cb) expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32) - self.assertArraysEqual(gda.local_data(0).to_py(), + self.assertArraysEqual(np.asarray(gda.local_data(0)), expected_first_shard_value) expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32) - self.assertArraysEqual(gda.local_data(1).to_py(), + self.assertArraysEqual(np.asarray(gda.local_data(1)), expected_second_shard_value) def test_gda_str_repr(self): diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index f359857fe..2eb769adc 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -113,7 +113,7 @@ class JaxJitTest(jtu.JaxTestCase): complex_type = dtypes.canonicalize_dtype(np.complex128) # int - res = _cpp_device_put(1, device).to_py() + res = np.asarray(_cpp_device_put(1, device)) self.assertEqual(res, 1) self.assertEqual(res.dtype, int_type) # We also compare to the Python Jax API, to make sure we have the exact @@ -122,20 +122,20 @@ class JaxJitTest(jtu.JaxTestCase): self.assertEqual(jnp.asarray(1).dtype, res.dtype) # float - res = _cpp_device_put(1.0, device).to_py() + res = np.asarray(_cpp_device_put(1.0, device)) self.assertEqual(res, 1.0) self.assertEqual(res.dtype, float_type) self.assertEqual(jnp.asarray(1.0).dtype, res.dtype) # bool for bool_value in [True, False]: - res = _cpp_device_put(bool_value, device).to_py() + res = np.asarray(_cpp_device_put(bool_value, device)) self.assertEqual(res, np.asarray(bool_value)) self.assertEqual(res.dtype, np.bool_) self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype) # Complex - res = _cpp_device_put(1 + 1j, device).to_py() + res = np.asarray(_cpp_device_put(1 + 1j, device)) self.assertEqual(res, 1 + 1j) self.assertEqual(res.dtype, complex_type) self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 063ac9453..c8d392c15 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -135,7 +135,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertLen(actual.device_buffers, 1) self.assertAllClose( - actual.device_buffers[0].to_py(), expected, check_dtypes=False) + np.asarray(actual.device_buffers[0]), expected, check_dtypes=False) # Repro for a bug on device_buffer aval _ = repr(actual.device_buffers) @@ -154,7 +154,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertLen(actual.device_buffers, 2) - self.assertAllClose(actual.device_buffers[0].to_py(), expected, + self.assertAllClose(np.asarray(actual.device_buffers[0]), expected, check_dtypes=False) @jtu.with_mesh([('x', 2)]) @@ -191,7 +191,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual[:3], expected[:3], check_dtypes=False) self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertLen(actual.device_buffers, 2) - self.assertAllClose(actual.device_buffers[0].to_py()[:3], expected[:3], + self.assertAllClose(np.asarray(actual.device_buffers[0])[:3], expected[:3], check_dtypes=False) def testBasic1DWithMeshContextManager(self): @@ -210,7 +210,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertLen(actual.device_buffers, 2) - self.assertAllClose(actual.device_buffers[0].to_py(), expected, + self.assertAllClose(np.asarray(actual.device_buffers[0]), expected, check_dtypes=False) @jtu.with_mesh([('x', 2), ('y', 2)]) @@ -232,13 +232,13 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertLen(actual.device_buffers, 4) split0, split1 = np.split(expected, 2) - self.assertAllClose(actual.device_buffers[0].to_py(), split0, + self.assertAllClose(np.asarray(actual.device_buffers[0]), split0, check_dtypes=False) - self.assertAllClose(actual.device_buffers[1].to_py(), split0, + self.assertAllClose(np.asarray(actual.device_buffers[1]), split0, check_dtypes=False) - self.assertAllClose(actual.device_buffers[2].to_py(), split1, + self.assertAllClose(np.asarray(actual.device_buffers[2]), split1, check_dtypes=False) - self.assertAllClose(actual.device_buffers[3].to_py(), split1, + self.assertAllClose(np.asarray(actual.device_buffers[3]), split1, check_dtypes=False) def testBasic2DWithMeshContextManager(self): @@ -261,13 +261,13 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertLen(actual.device_buffers, 4) split0, split1 = np.split(expected, 2) - self.assertAllClose(actual.device_buffers[0].to_py(), split0, + self.assertAllClose(np.asarray(actual.device_buffers[0]), split0, check_dtypes=False) - self.assertAllClose(actual.device_buffers[1].to_py(), split0, + self.assertAllClose(np.asarray(actual.device_buffers[1]), split0, check_dtypes=False) - self.assertAllClose(actual.device_buffers[2].to_py(), split1, + self.assertAllClose(np.asarray(actual.device_buffers[2]), split1, check_dtypes=False) - self.assertAllClose(actual.device_buffers[3].to_py(), split1, + self.assertAllClose(np.asarray(actual.device_buffers[3]), split1, check_dtypes=False) def testDifferentNestedMesh(self): @@ -318,13 +318,13 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertLen(actual.device_buffers, 4) splits = np.split(expected, 4) - self.assertAllClose(actual.device_buffers[0].to_py(), splits[0], + self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0], check_dtypes=False) - self.assertAllClose(actual.device_buffers[1].to_py(), splits[1], + self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1], check_dtypes=False) - self.assertAllClose(actual.device_buffers[2].to_py(), splits[2], + self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2], check_dtypes=False) - self.assertAllClose(actual.device_buffers[3].to_py(), splits[3], + self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3], check_dtypes=False) @jtu.with_mesh([('x', 2)]) @@ -363,7 +363,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertIsInstance(actual, pxla.ShardedDeviceArray) self.assertLen(actual.device_buffers, 2) - self.assertAllClose(actual.device_buffers[0].to_py(), expected, + self.assertAllClose(np.asarray(actual.device_buffers[0]), expected, check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") @@ -390,7 +390,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertIsInstance(actual, array.Array) self.assertLen(actual.addressable_shards, 2) - self.assertAllClose(actual._arrays[0].to_py(), expected, + self.assertAllClose(np.asarray(actual._arrays[0]), expected, check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") @@ -419,7 +419,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertIsInstance(actual, array.Array) self.assertLen(actual.addressable_shards, 2) - self.assertAllClose(actual._arrays[0].to_py(), expected, + self.assertAllClose(np.asarray(actual._arrays[0]), expected, check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") @@ -842,13 +842,13 @@ class PJitTest(jtu.BufferDonationTestCase): ((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {})) splits = np.split(expected, 4) - self.assertAllClose(actual.device_buffers[0].to_py(), splits[0], + self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0], check_dtypes=False) - self.assertAllClose(actual.device_buffers[1].to_py(), splits[1], + self.assertAllClose(np.asarray(actual.device_buffers[1]), splits[1], check_dtypes=False) - self.assertAllClose(actual.device_buffers[2].to_py(), splits[2], + self.assertAllClose(np.asarray(actual.device_buffers[2]), splits[2], check_dtypes=False) - self.assertAllClose(actual.device_buffers[3].to_py(), splits[3], + self.assertAllClose(np.asarray(actual.device_buffers[3]), splits[3], check_dtypes=False) for obj in [lowered, compiled]: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index b903263a3..5e7f4259d 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3016,7 +3016,7 @@ class ShardArgsTest(jtu.JaxTestCase): self.assertEqual(len(bufs), 1) self.assertEqual(len(bufs[0]), nshards) for buf, idx in zip(bufs[0], indices): - self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False) + self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False) class ArrayPmapTest(jtu.JaxTestCase):