From 97beb01c43b933cf9cd74bfa52d28c576e74b854 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 29 Nov 2023 16:52:09 -0800
Subject: [PATCH] Deprecate the device() method of JAX arrays

---
 CHANGELOG.md                         |   5 +-
 docs/faq.rst                         |   9 +-
 jax/_src/array.py                    |  23 +--
 jax/_src/interpreters/pxla.py        |   2 +-
 tests/api_test.py                    |  61 ++++----
 tests/array_interoperability_test.py |   6 +-
 tests/array_test.py                  |   4 +-
 tests/multi_device_test.py           |   4 +-
 tests/multibackend_test.py           |  36 ++---
 tests/notebooks/colab_cpu.ipynb      |   2 +-
 tests/notebooks/colab_gpu.ipynb      | 204 +++++++++++++--------------
 tests/pjit_test.py                   |  22 +--
 tests/pmap_test.py                   |  22 +--
 tests/random_test.py                 |   3 +-
 14 files changed, 209 insertions(+), 194 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b32215415..f97d953d7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
     but the previous outputs can be recovered this way:
     * `arr.device_buffer` becomes `arr.addressable_data(0)`
     * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
-
+  * The `device()` method of JAX arrays is deprecated. Depending on the context, it may
+    be replaced with one of the following:
+    - {meth}`jax.Array.devices` returns the set of all devices used by the array.
+    - {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
 
 ## jaxlib 0.4.21
 
diff --git a/docs/faq.rst b/docs/faq.rst
index a2db17743..8102469c1 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
 platforms are available in priority order).
 
 >>> from jax import numpy as jnp
->>> print(jnp.ones(3).device()) # doctest: +SKIP
-cuda:0
+>>> print(jnp.ones(3).devices()) # doctest: +SKIP
+{CudaDevice(id=0)}
 
 Computations involving uncommitted data are performed on the default
 device and the results are uncommitted on the default device.
@@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the
 
 >>> import jax
 >>> from jax import device_put
->>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP
-cuda:2
+>>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP
+>>> print(arr.devices()) # doctest: +SKIP
+{CudaDevice(id=2)}
 
 Computations involving some committed inputs will happen on the
 committed device and the result will be committed on the
diff --git a/jax/_src/array.py b/jax/_src/array.py
index bd12a7f9f..65aba1cee 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -51,6 +51,11 @@ Device = xc.Device
 Index = tuple[slice, ...]
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.
 
+def _get_device(a: ArrayImpl) -> Device:
+  assert len(a.devices()) == 1
+  return next(iter(a.devices()))
+
+
 class Shard:
   """A single data shard of an Array.
 
@@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
   di_map = _cached_index_calc(s, shape)
   copy_plan = []
   for a in arrays:
-    ind = di_map.get(a.device(), None)
+    ind = di_map.get(_get_device(a), None)
     if ind is not None:
       copy_plan.append((ind, a))
   return copy_plan
@@ -183,7 +188,7 @@ class ArrayImpl(basearray.Array):
             "Input buffers to `Array` must have matching dtypes. "
" f"Got {db.dtype}, expected {self.dtype} for buffer: {db}") - device_id_to_buffer = {db.device().id: db for db in self._arrays} + device_id_to_buffer = {_get_device(db).id: db for db in self._arrays} addressable_dev = self.sharding.addressable_devices if len(self._arrays) != len(addressable_dev): @@ -324,7 +329,7 @@ class ArrayImpl(basearray.Array): if arr_idx is not None: a = self._arrays[arr_idx] return ArrayImpl( - a.aval, SingleDeviceSharding(a.device()), [a], committed=False, + a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False, _skip_checks=True) return lax_numpy._rewriting_take(self, idx) else: @@ -400,7 +405,7 @@ class ArrayImpl(basearray.Array): return DLDeviceType.kDLCPU, 0 elif self.platform() == "gpu": - platform_version = self.device().client.platform_version + platform_version = _get_device(self).client.platform_version if "cuda" in platform_version: dl_device_type = DLDeviceType.kDLCUDA elif "rocm" in platform_version: @@ -409,7 +414,7 @@ class ArrayImpl(basearray.Array): raise ValueError("Unknown GPU platform for __dlpack__: " f"{platform_version}") - local_hardware_id = self.device().local_hardware_id + local_hardware_id = _get_device(self).local_hardware_id if local_hardware_id is None: raise ValueError("Couldn't get local_hardware_id for __dlpack__") @@ -451,6 +456,8 @@ class ArrayImpl(basearray.Array): # TODO(yashkatariya): Remove this method when everyone is using devices(). def device(self) -> Device: + warnings.warn("arr.device() is deprecated. Use arr.devices() instead.", + DeprecationWarning, stacklevel=2) self._check_if_deleted() device_set = self.sharding.device_set if len(device_set) == 1: @@ -499,7 +506,7 @@ class ArrayImpl(basearray.Array): self._check_if_deleted() out = [] for a in self._arrays: - out.append(Shard(a.device(), self.sharding, self.shape, a)) + out.append(Shard(_get_device(a), self.sharding, self.shape, a)) return out @property @@ -514,7 +521,7 @@ class ArrayImpl(basearray.Array): return self.addressable_shards out = [] - device_id_to_buffer = {a.device().id: a for a in self._arrays} + device_id_to_buffer = {_get_device(a).id: a for a in self._arrays} for global_d in self.sharding.device_set: if device_id_to_buffer.get(global_d.id, None) is not None: array = device_id_to_buffer[global_d.id] @@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: - if buf.device() == device: + if buf.devices() == {device}: bufs.append(buf) break else: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 488433a4e..a8bcdac2b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray, bufs = [x for x, d in safe_zip(xs, devices) if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and - x.device() == d)] + x.devices() == {d})] if len(bufs) == len(xs): return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) diff --git a/tests/api_test.py b/tests/api_test.py index 05e730432..8d301c742 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -228,7 +228,7 @@ class JitTest(jtu.BufferDonationTestCase): device = jax.devices()[-1] x = jit(lambda x: x, device=device)(3.) 
     _check_instance(self, x)
-    self.assertEqual(x.device(), device)
+    self.assertEqual(x.devices(), {device})
 
   @parameterized.named_parameters(
     ('jit', jax.jit),
@@ -239,42 +239,44 @@ class JitTest(jtu.BufferDonationTestCase):
     if jax.device_count() == 1:
       raise unittest.SkipTest("Test requires multiple devices")
 
-    system_default_device = jnp.add(1, 1).device()
+    system_default_devices = jnp.add(1, 1).devices()
+    self.assertLen(system_default_devices, 1)
+    system_default_device = list(system_default_devices)[0]
     test_device = jax.devices()[-1]
     self.assertNotEqual(system_default_device, test_device)
 
     f = module(lambda x: x + 1)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(f(1).devices(), system_default_devices)
 
     with jax.default_device(test_device):
-      self.assertEqual(jnp.add(1, 1).device(), test_device)
-      self.assertEqual(f(1).device(), test_device)
+      self.assertEqual(jnp.add(1, 1).devices(), {test_device})
+      self.assertEqual(f(1).devices(), {test_device})
 
-    self.assertEqual(jnp.add(1, 1).device(), system_default_device)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
+    self.assertEqual(f(1).devices(), system_default_devices)
 
     with jax.default_device(test_device):
       # Explicit `device` or `backend` argument to jit overrides default_device
       self.assertEqual(
-          module(f, device=system_default_device)(1).device(),
-          system_default_device)
+          module(f, device=system_default_device)(1).devices(),
+          system_default_devices)
       out = module(f, backend="cpu")(1)
-      self.assertEqual(out.device().platform, "cpu")
+      self.assertEqual(next(iter(out.devices())).platform, "cpu")
 
       # Sticky input device overrides default_device
       sticky = jax.device_put(1, system_default_device)
-      self.assertEqual(jnp.add(sticky, 1).device(), system_default_device)
-      self.assertEqual(f(sticky).device(), system_default_device)
+      self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
+      self.assertEqual(f(sticky).devices(), system_default_devices)
 
       # Test nested default_devices
       with jax.default_device(system_default_device):
-        self.assertEqual(f(1).device(), system_default_device)
-      self.assertEqual(f(1).device(), test_device)
+        self.assertEqual(f(1).devices(), system_default_devices)
+      self.assertEqual(f(1).devices(), {test_device})
 
     # Test a few more non-default_device calls for good luck
-    self.assertEqual(jnp.add(1, 1).device(), system_default_device)
-    self.assertEqual(f(sticky).device(), system_default_device)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
+    self.assertEqual(f(sticky).devices(), system_default_devices)
+    self.assertEqual(f(1).devices(), system_default_devices)
 
   # TODO(skye): make this work!
   def test_jit_default_platform(self):
@@ -815,8 +817,8 @@ class JitTest(jtu.BufferDonationTestCase):
 
     result = jitted_f(1.)
     result_cpu = jitted_f_cpu(1.)
-    self.assertEqual(result.device().platform, jtu.device_under_test())
-    self.assertEqual(result_cpu.device().platform, "cpu")
+    self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
+    self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")
 
   @parameterized.named_parameters(
     ('jit', jax.jit),
@@ -1697,7 +1699,7 @@ class APITest(jtu.JaxTestCase):
 
     u = jax.device_put(y, jax.devices()[0])
     self.assertArraysAllClose(u, y)
-    self.assertEqual(u.device(), jax.devices()[0])
+    self.assertEqual(u.devices(), {jax.devices()[0]})
 
   def test_device_put_sharding_tree(self):
     if jax.device_count() < 2:
@@ -1830,10 +1832,10 @@ class APITest(jtu.JaxTestCase):
     d1, d2 = jax.local_devices()[:2]
     data = self.rng().randn(*shape).astype(np.float32)
     x = api.device_put(data, device=d1)
-    self.assertEqual(x.device(), d1)
+    self.assertEqual(x.devices(), {d1})
 
     y = api.device_put(x, device=d2)
-    self.assertEqual(y.device(), d2)
+    self.assertEqual(y.devices(), {d2})
 
     np.testing.assert_array_equal(data, np.array(y))
     # Make sure these don't crash
@@ -1848,11 +1850,11 @@ class APITest(jtu.JaxTestCase):
     np_arr = np.array([1,2,3])
     scalar = 1
     device_arr = jnp.array([1,2,3])
-    assert device_arr.device() is default_device
+    assert device_arr.devices() == {default_device}
 
     for val in [np_arr, device_arr, scalar]:
       x = api.device_put(val, device=cpu_device)
-      self.assertEqual(x.device(), cpu_device)
+      self.assertEqual(x.devices(), {cpu_device})
 
   @jax.default_matmul_precision("float32")
   def test_jacobian(self):
@@ -3852,21 +3854,22 @@ class APITest(jtu.JaxTestCase):
 
   @jtu.skip_on_devices("cpu")
   def test_default_device(self):
-    system_default_device = jnp.zeros(2).device()
+    system_default_devices = jnp.add(1, 1).devices()
+    self.assertLen(system_default_devices, 1)
     test_device = jax.devices("cpu")[-1]
 
     # Sanity check creating array using system default device
-    self.assertEqual(jnp.ones(1).device(), system_default_device)
+    self.assertEqual(jnp.ones(1).devices(), system_default_devices)
 
     # Create array with default_device set
     with jax.default_device(test_device):
       # Hits cached primitive path
-      self.assertEqual(jnp.ones(1).device(), test_device)
+      self.assertEqual(jnp.ones(1).devices(), {test_device})
       # Uncached
-      self.assertEqual(jnp.zeros((1, 2)).device(), test_device)
+      self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})
 
     # Test that we can reset to system default device
-    self.assertEqual(jnp.ones(1).device(), system_default_device)
+    self.assertEqual(jnp.ones(1).devices(), system_default_devices)
 
   def test_dunder_jax_array(self):
     # https://github.com/google/jax/pull/4725
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index 629634865..3db0fce66 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -77,7 +77,7 @@ class DLPackTest(jtu.JaxTestCase):
     x = jax.device_put(np, device)
     dlpack = jax.dlpack.to_dlpack(x)
     y = jax.dlpack.from_dlpack(dlpack)
-    self.assertEqual(y.device(), device)
+    self.assertEqual(y.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), y)
 
     self.assertRaisesRegex(RuntimeError,
@@ -97,11 +97,11 @@ class DLPackTest(jtu.JaxTestCase):
     device = jax.devices("gpu" if gpu else "cpu")[0]
     x = jax.device_put(np, device)
     y = jax.dlpack.from_dlpack(x)
-    self.assertEqual(y.device(), device)
+    self.assertEqual(y.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), y)
 
     # Test we can create multiple arrays
     z = jax.dlpack.from_dlpack(x)
-    self.assertEqual(z.device(), device)
+    self.assertEqual(z.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), z)
 
diff --git a/tests/array_test.py b/tests/array_test.py
index 05dfb392b..bb2d25a16 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -424,7 +424,7 @@ class JaxArrayTest(jtu.JaxTestCase):
 
     x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
     y = jax.pmap(jnp.sin)(x)
-    self.assertArraysEqual([a.device() for a in y],
+    self.assertArraysEqual([list(a.devices())[0] for a in y],
                            y.sharding._device_assignment,
                            allow_object_dtype=True)
 
@@ -550,7 +550,7 @@ class JaxArrayTest(jtu.JaxTestCase):
 
     for i, j in zip(arr, iter(input_data)):
       self.assertArraysEqual(i, j)
-      self.assertEqual(i.device(), single_dev[0])
+      self.assertEqual(i.devices(), {single_dev[0]})
 
   def test_array_shards_committed(self):
     if jax.device_count() < 2:
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index c373c8721..1344aba6c 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -61,12 +61,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
   def assert_committed_to_device(self, data, device):
     """Asserts that the data is committed to the device."""
     self.assertTrue(data._committed)
-    self.assertEqual(data.device(), device)
+    self.assertEqual(data.devices(), {device})
 
   def assert_uncommitted_to_device(self, data, device):
     """Asserts that the data is on the device but not committed to it."""
     self.assertFalse(data._committed)
-    self.assertEqual(data.device(), device)
+    self.assertEqual(data.devices(), {device})
 
   def test_computation_follows_data(self):
     if jax.device_count() < 5:
diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py
index 945550813..40cbb6630 100644
--- a/tests/multibackend_test.py
+++ b/tests/multibackend_test.py
@@ -48,7 +48,7 @@ class MultiBackendTest(jtu.JaxTestCase):
     z = fun(x, y)
     self.assertAllClose(z, z_host, rtol=1e-2)
     correct_platform = backend if backend else jtu.device_under_test()
-    self.assertEqual(z.device().platform, correct_platform)
+    self.assertEqual(list(z.devices())[0].platform, correct_platform)
 
   @jtu.sample_product(
     ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
@@ -72,7 +72,7 @@ class MultiBackendTest(jtu.JaxTestCase):
     z = fun(x, y)
     self.assertAllClose(z, z_host, rtol=1e-2)
     correct_platform = outer if outer else jtu.device_under_test()
-    self.assertEqual(z.device().platform, correct_platform)
+    self.assertEqual(list(z.devices())[0].platform, correct_platform)
 
   @jtu.sample_product(
     ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
@@ -116,8 +116,8 @@ class MultiBackendTest(jtu.JaxTestCase):
     y = npr.uniform(size=(10,10))
     z = fun(x, y)
     w = jnp.sin(z)
-    self.assertEqual(z.device().platform, backend)
-    self.assertEqual(w.device().platform, backend)
+    self.assertEqual(list(z.devices())[0].platform, backend)
+    self.assertEqual(list(w.devices())[0].platform, backend)
 
   @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
   def testJitCpu(self):
@@ -131,18 +131,18 @@ class MultiBackendTest(jtu.JaxTestCase):
     b = x + jnp.ones_like(x)
     c = x + jnp.eye(2)
-    self.assertEqual(a.device(), jax.devices('cpu')[0])
-    self.assertEqual(b.device(), jax.devices('cpu')[0])
-    self.assertEqual(c.device(), jax.devices('cpu')[0])
+    self.assertEqual(a.devices(), {jax.devices('cpu')[0]})
+    self.assertEqual(b.devices(), {jax.devices('cpu')[0]})
+    self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
 
   @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
   def test_closed_over_values_device_placement(self):
     # see https://github.com/google/jax/issues/1431
     def f(): return jnp.add(3., 4.)
-    self.assertNotEqual(jax.jit(f)().device(),
-                        jax.devices('cpu')[0])
-    self.assertEqual(jax.jit(f, backend='cpu')().device(),
-                     jax.devices('cpu')[0])
+    self.assertNotEqual(jax.jit(f)().devices(),
+                        {jax.devices('cpu')[0]})
+    self.assertEqual(jax.jit(f, backend='cpu')().devices(),
+                     {jax.devices('cpu')[0]})
 
   @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
   def test_jit_on_nondefault_backend(self):
@@ -154,22 +154,22 @@ class MultiBackendTest(jtu.JaxTestCase):
     self.assertNotEqual(default_dev.platform, "cpu")
 
     data_on_cpu = jax.device_put(1, device=cpus[0])
-    self.assertEqual(data_on_cpu.device(), cpus[0])
+    self.assertEqual(data_on_cpu.devices(), {cpus[0]})
 
     def my_sin(x): return jnp.sin(x)
 
     # jit without any device spec follows the data
     result1 = jax.jit(my_sin)(2)
-    self.assertEqual(result1.device(), default_dev)
+    self.assertEqual(result1.devices(), {default_dev})
     result2 = jax.jit(my_sin)(data_on_cpu)
-    self.assertEqual(result2.device(), cpus[0])
+    self.assertEqual(result2.devices(), {cpus[0]})
 
     # jit with `device` spec places the data on the specified device
     result3 = jax.jit(my_sin, device=cpus[0])(2)
-    self.assertEqual(result3.device(), cpus[0])
+    self.assertEqual(result3.devices(), {cpus[0]})
 
     # jit with `backend` spec places the data on the specified backend
     result4 = jax.jit(my_sin, backend="cpu")(2)
-    self.assertEqual(result4.device(), cpus[0])
+    self.assertEqual(result4.devices(), {cpus[0]})
 
   @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
   def test_indexing(self):
@@ -178,7 +178,7 @@ class MultiBackendTest(jtu.JaxTestCase):
 
     x = jax.device_put(np.ones(2), cpus[0])
     y = x[0]
-    self.assertEqual(y.device(), cpus[0])
+    self.assertEqual(y.devices(), {cpus[0]})
 
   @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
   def test_sum(self):
@@ -187,7 +187,7 @@ class MultiBackendTest(jtu.JaxTestCase):
 
     x = jax.device_put(np.ones(2), cpus[0])
     y = x.sum()
-    self.assertEqual(y.device(), cpus[0])
+    self.assertEqual(y.devices(), {cpus[0]})
 
 
 if __name__ == "__main__":
diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb
index 7d484b964..2e4b4fdb9 100644
--- a/tests/notebooks/colab_cpu.ipynb
+++ b/tests/notebooks/colab_cpu.ipynb
@@ -98,7 +98,7 @@
     "import jax\n",
     "key = jax.random.PRNGKey(1701)\n",
     "arr = jax.random.normal(key, (1000,))\n",
-    "device = arr.device()\n",
+    "device = list(arr.devices())[0]\n",
     "print(f\"JAX device type: {device}\")\n",
     "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
   ]
diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb
index fead8b947..8352bdaf7 100644
--- a/tests/notebooks/colab_gpu.ipynb
+++ b/tests/notebooks/colab_gpu.ipynb
@@ -1,24 +1,10 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "JAX Colab GPU Test",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "accelerator": "GPU"
-  },
   "cells": [
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "view-in-github",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "view-in-github"
       },
       "source": [
         "\"Open"
       ]
     },
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "WkadOyTDCAWD",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "WkadOyTDCAWD"
       },
       "source": [
         "# JAX Colab GPU Test\n",
         ""
       ]
     },
     {
       "cell_type": "code",
+      "execution_count": 1,
       "metadata": {
-        "id": "_tKNrbqqBHwu",
-        "colab_type": "code",
-        "outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33", "colab": { "base_uri": "https://localhost:8080/", "height": 68 - } + }, + "colab_type": "code", + "id": "_tKNrbqqBHwu", + "outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gpu-t4-s-kbefivsjoreh\n", + "0.1.64\n", + "0.1.45\n" + ] + } + ], "source": [ "import jax\n", "import jaxlib\n", @@ -54,25 +52,13 @@ "!cat /var/colab/hostname\n", "print(jax.__version__)\n", "print(jaxlib.__version__)" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "text": [ - "gpu-t4-s-kbefivsjoreh\n", - "0.1.64\n", - "0.1.45\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "oqEG21rADO1F", - "colab_type": "text" + "colab_type": "text", + "id": "oqEG21rADO1F" }, "source": [ "## Confirm Device" @@ -80,39 +66,39 @@ }, { "cell_type": "code", + "execution_count": 2, "metadata": { - "colab_type": "code", - "id": "8BwzMYhKGQj6", - "outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f", "colab": { "base_uri": "https://localhost:8080/", "height": 34 - } + }, + "colab_type": "code", + "id": "8BwzMYhKGQj6", + "outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX device type: gpu:0\n" + ] + } + ], "source": [ "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device()\n", + "device = list(arr.devices())[0]\n", "print(f\"JAX device type: {device}\")\n", "assert device.platform == \"gpu\", \"unexpected JAX device type\"" - ], - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "text": [ - "JAX device type: gpu:0\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "z0FUY9yUC4k1", - "colab_type": "text" + "colab_type": "text", + "id": "z0FUY9yUC4k1" }, "source": [ "## Matrix Multiplication" @@ -120,15 +106,25 @@ }, { "cell_type": "code", + "execution_count": 3, "metadata": { - "colab_type": "code", - "id": "eXn8GUl6CG5N", - "outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9", "colab": { "base_uri": "https://localhost:8080/", "height": 34 - } + }, + "colab_type": "code", + "id": "eXn8GUl6CG5N", + "outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0216676\n" + ] + } + ], "source": [ "import jax\n", "import numpy as np\n", @@ -138,23 +134,13 @@ "x = jax.random.normal(key, (3000, 3000))\n", "result = jax.numpy.dot(x, x.T).mean()\n", "print(result)" - ], - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "text": [ - "1.0216676\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "0zTA2Q19DW4G", - "colab_type": "text" + "colab_type": "text", + "id": "0zTA2Q19DW4G" }, "source": [ "## Linear Algebra" @@ -162,15 +148,26 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { - "id": "uW9j84_UDYof", - "colab_type": "code", - "outputId": "80069760-12ab-4df2-9f5c-be2536de59b7", "colab": { "base_uri": "https://localhost:8080/", "height": 51 - } + }, + "colab_type": "code", + "id": "uW9j84_UDYof", + "outputId": "80069760-12ab-4df2-9f5c-be2536de59b7" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n", + " 2.866489 1.8229384 1.5478926]\n" + ] + } + ], "source": [ "import jax.numpy as jnp\n", "import jax.random 
         "assert u.shape == (N, N)\n",
         "assert vt.shape == (M, M)\n",
         "print(s)"
-      ],
-      "execution_count": 4,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "text": [
-            "[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
-            " 2.866489 1.8229384 1.5478926]\n"
-          ],
-          "name": "stdout"
-        }
-      ]
+      ]
     },
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "jCyKUn4-DCXn",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "jCyKUn4-DCXn"
       },
       "source": [
         "## XLA Compilation"
       ]
     },
     {
       "cell_type": "code",
+      "execution_count": 5,
       "metadata": {
-        "colab_type": "code",
-        "id": "2GOn_HhDPuEn",
-        "outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8",
         "colab": {
           "base_uri": "https://localhost:8080/",
           "height": 51
-        }
+        },
+        "colab_type": "code",
+        "id": "2GOn_HhDPuEn",
+        "outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
       },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
+            " 0.13093245]\n"
+          ]
+        }
+      ],
       "source": [
         "@jax.jit\n",
         "def selu(x, alpha=1.67, lmbda=1.05):\n",
         "x = jax.random.normal(key, (5000,))\n",
         "result = selu(x).block_until_ready()\n",
         "print(result)"
-      ],
-      "execution_count": 5,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "text": [
-            "[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
-            " 0.13093245]\n"
-          ],
-          "name": "stdout"
-        }
-      ]
+      ]
     }
-  ]
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "collapsed_sections": [],
+      "name": "JAX Colab GPU Test",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
 }
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b00bcf6b6..7b40816ad 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -2539,7 +2539,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertFalse(a._committed)
     out = f(a, a)
     self.assertFalse(out._committed)
-    self.assertEqual(out.device(), jax.devices()[0])
+    self.assertEqual(out.devices(), {jax.devices()[0]})
     self.assertArraysEqual(out, a * 2)
 
     with jax.default_device(jax.devices()[1]):
@@ -2547,7 +2547,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       self.assertFalse(b._committed)
       out2 = f(b, b)
       self.assertFalse(out2._committed)
-      self.assertEqual(out2.device(), jax.devices()[1])
+      self.assertEqual(out2.devices(), {jax.devices()[1]})
       self.assertArraysEqual(out2, b * 2)
 
   def test_pjit_with_static_argnames(self):
@@ -2590,7 +2590,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     if jax.device_count() <= 1:
       self.skipTest('Test requires more >1 device.')
 
-    system_default_device = jnp.add(1, 1).device()
+    system_default_device = list(jnp.add(1, 1).devices())[0]
     test_device = jax.devices()[-1]
 
     f = pjit(lambda x: x + 1)
@@ -2733,7 +2733,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       return x @ x.T
 
     def _check(out, expected_device, expected_out):
-      self.assertEqual(out.device(), expected_device)
+      self.assertEqual(out.devices(), {expected_device})
       self.assertLen(out.sharding.device_set, 1)
       self.assertArraysEqual(out, expected_out @ expected_out.T)
 
@@ -2776,14 +2776,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     expected_device = jax.devices()[2]
     final_out = pjit(lambda x: x * 3, device=expected_device)(out)
-    self.assertEqual(final_out.device(), expected_device)
+    self.assertEqual(final_out.devices(), {expected_device})
     self.assertLen(final_out.sharding.device_set, 1)
     self.assertArraysEqual(final_out, inp * 6)
 
   @jtu.run_on_devices("tpu")
   def test_pjit_with_backend_arg(self):
     def _check(out, expected_device, expected_out):
-      self.assertEqual(out.device(), expected_device)
+      self.assertEqual(out.devices(), {expected_device})
       self.assertLen(out.sharding.device_set, 1)
       self.assertArraysEqual(out, expected_out)
 
@@ -3403,7 +3403,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     y = jax.device_put(x, jax.devices()[1])
     out2 = jax.jit(lambda x: x)(y)
     self.assertIsInstance(out2.sharding, SingleDeviceSharding)
-    self.assertEqual(out2.device(), jax.devices()[1])
+    self.assertEqual(out2.devices(), {jax.devices()[1]})
 
     out3 = jax.jit(lambda x: x * 2)(x)
     self.assertIsInstance(out3.sharding, SingleDeviceSharding)
@@ -3411,7 +3411,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     out4 = jax.jit(lambda x: x * 3,
                    out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
-    self.assertEqual(out4.device(), jax.devices()[1])
+    self.assertEqual(out4.devices(), {jax.devices()[1]})
 
   def test_none_out_sharding(self):
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@@ -3449,7 +3449,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     arr4 = jax.device_put(jnp.arange(8), jax.devices()[1])
     out4 = jnp.copy(arr4)
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
-    self.assertEqual(out4.device(), jax.devices()[1])
+    self.assertEqual(out4.devices(), {jax.devices()[1]})
 
   def test_get_indices_cache(self):
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -3506,12 +3506,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
     # Fill up the to_gspmd_sharding cache so that the next jit will miss it.
     out = jax.jit(identity,
                   in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp)
-    self.assertEqual(out.device(), jax.devices()[0])
+    self.assertEqual(out.devices(), {jax.devices()[0]})
     self.assertArraysEqual(out, np_inp)
 
     out2 = jax.jit(identity, device=jax.devices()[0])(
         jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
-    self.assertEqual(out2.device(), jax.devices()[0])
+    self.assertEqual(out2.devices(), {jax.devices()[0]})
     self.assertArraysEqual(out2, np_inp)
 
   def test_jit_submhlo_cached(self):
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index ef2816ff0..f79245f68 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -147,12 +147,12 @@ class PythonPmapTest(jtu.JaxTestCase):
 
     view = jnp.array(buf, copy=False)
     self.assertArraysEqual(sda[-1], view)
-    self.assertEqual(buf.device(), view.device())
+    self.assertSetEqual(buf.devices(), view.devices())
     self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer())
 
     copy = jnp.array(buf, copy=True)
     self.assertArraysEqual(sda[-1], copy)
-    self.assertEqual(buf.device(), copy.device())
+    self.assertSetEqual(buf.devices(), copy.devices())
     self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
 
   def _getMeshShape(self, device_mesh_shape):
@@ -869,7 +869,7 @@ class PythonPmapTest(jtu.JaxTestCase):
     # test that we can handle device movement on dispatch
     bufs = y._arrays[::-1]
     sharding = jax.sharding.PmapSharding(
-        [b.device() for b in bufs], y.sharding.sharding_spec)
+        [list(b.devices())[0] for b in bufs], y.sharding.sharding_spec)
     y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs)
     z = f(y)
     self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
@@ -2769,7 +2769,7 @@ class ArrayTest(jtu.JaxTestCase):
       self.assertEqual(s.replica_id, 0)
     buffers = getattr(y, '_arrays')
     self.assertEqual(len(buffers), len(devices))
-    self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
     self.assertArraysEqual(y, jnp.stack(x))
 
   def test_device_put_sharded_pytree(self):
@@ -2781,12 +2781,12 @@ class ArrayTest(jtu.JaxTestCase):
     self.assertIsInstance(y1, array.ArrayImpl)
     self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
     y1_buffers = getattr(y1, '_arrays')
-    self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
 
     self.assertIsInstance(y2, array.ArrayImpl)
     self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
     y2_buffers = getattr(y2, '_arrays')
-    self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
 
   def test_device_put_replicated(self):
     devices = jax.local_devices()
@@ -2796,7 +2796,7 @@ class ArrayTest(jtu.JaxTestCase):
     self.assertIsInstance(y, array.ArrayImpl)
     buffers = getattr(y, '_arrays')
     self.assertEqual(len(buffers), len(devices))
-    self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
     self.assertArraysEqual(y, np.stack([x for _ in devices]))
 
   def test_device_put_replicated_pytree(self):
@@ -2809,13 +2809,13 @@ class ArrayTest(jtu.JaxTestCase):
     self.assertIsInstance(y1, array.ArrayImpl)
     y1_buffers = getattr(y1, '_arrays')
     self.assertEqual(len(y1_buffers), len(devices))
-    self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
     self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
 
     self.assertIsInstance(y2, array.ArrayImpl)
     y2_buffers = getattr(y2, '_arrays')
     self.assertEqual(len(y2_buffers), len(devices))
-    self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
+    self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
     self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
 
   def test_repr(self):
@@ -3127,8 +3127,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
      self.skipTest('Test requires >= 2 devices.')
 
    def amap(f, xs):
-      ys = [f(jax.device_put(x, x.device())) for x in xs]
-      return jax.device_put_sharded(ys, [y.device() for y in ys])
+      ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
+      return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
 
     # leading axis is batch dim (i.e. mapped/parallel dim), of size 2
     x = jnp.array([[1., 0., 0.],
diff --git a/tests/random_test.py b/tests/random_test.py
index 0106923e5..52ce462e3 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -1026,7 +1026,8 @@ class KeyArrayTest(jtu.JaxTestCase):
     self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
     self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
-    self.assertEqual(key.device(), key._base_array.device())
+    with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
+      self.assertEqual(key.device(), key._base_array.device())
     self.assertEqual(key.devices(), key._base_array.devices())
     self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
     self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
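
Usage note: the following is a minimal sketch of the migration described in the CHANGELOG entry above, runnable on any single-device backend. The helper `get_single_device` is illustrative only; it mirrors the patch's private `_get_device` and is not part of the public JAX API.

    import jax
    import jax.numpy as jnp

    arr = jnp.arange(4.0)

    # Previously: dev = arr.device()
    # Under this patch, device() emits a DeprecationWarning; devices()
    # returns a set, because a jax.Array may span multiple devices.
    devs = arr.devices()
    assert len(devs) == 1  # holds for an ordinary single-device array

    # Illustrative helper (not a JAX API), mirroring the private _get_device:
    # unwrap the set when the array is known to live on a single device.
    def get_single_device(a):
      assert len(a.devices()) == 1
      return next(iter(a.devices()))

    print(get_single_device(arr).platform)  # e.g. "cpu"

    # The sharding carries the full placement information:
    print(arr.sharding.device_set)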