Deprecate the device() method of JAX arrays

This commit is contained in:
Jake VanderPlas 2023-11-29 16:52:09 -08:00
parent 4de07b3f62
commit 97beb01c43
14 changed files with 209 additions and 194 deletions

View File

@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
* The `device()` method of JAX arrays deprecated. Depending on the context, it may
be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
## jaxlib 0.4.21

View File

@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
platforms are available in priority order).
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device()) # doctest: +SKIP
cuda:0
>>> print(jnp.ones(3).devices()) # doctest: +SKIP
{CudaDevice(id=0)}
Computations involving uncommitted data are performed on the default
device and the results are uncommitted on the default device.
@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the
>>> import jax
>>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP
cuda:2
>>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP
>>> print(arr.devices()) # doctest: +SKIP
{CudaDevice(id=2)}
Computations involving some committed inputs will happen on the
committed device and the result will be committed on the

View File

@ -51,6 +51,11 @@ Device = xc.Device
Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
return next(iter(a.devices()))
class Shard:
"""A single data shard of an Array.
@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
di_map = _cached_index_calc(s, shape)
copy_plan = []
for a in arrays:
ind = di_map.get(a.device(), None)
ind = di_map.get(_get_device(a), None)
if ind is not None:
copy_plan.append((ind, a))
return copy_plan
@ -183,7 +188,7 @@ class ArrayImpl(basearray.Array):
"Input buffers to `Array` must have matching dtypes. "
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")
device_id_to_buffer = {db.device().id: db for db in self._arrays}
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
@ -324,7 +329,7 @@ class ArrayImpl(basearray.Array):
if arr_idx is not None:
a = self._arrays[arr_idx]
return ArrayImpl(
a.aval, SingleDeviceSharding(a.device()), [a], committed=False,
a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
_skip_checks=True)
return lax_numpy._rewriting_take(self, idx)
else:
@ -400,7 +405,7 @@ class ArrayImpl(basearray.Array):
return DLDeviceType.kDLCPU, 0
elif self.platform() == "gpu":
platform_version = self.device().client.platform_version
platform_version = _get_device(self).client.platform_version
if "cuda" in platform_version:
dl_device_type = DLDeviceType.kDLCUDA
elif "rocm" in platform_version:
@ -409,7 +414,7 @@ class ArrayImpl(basearray.Array):
raise ValueError("Unknown GPU platform for __dlpack__: "
f"{platform_version}")
local_hardware_id = self.device().local_hardware_id
local_hardware_id = _get_device(self).local_hardware_id
if local_hardware_id is None:
raise ValueError("Couldn't get local_hardware_id for __dlpack__")
@ -451,6 +456,8 @@ class ArrayImpl(basearray.Array):
# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
device_set = self.sharding.device_set
if len(device_set) == 1:
@ -499,7 +506,7 @@ class ArrayImpl(basearray.Array):
self._check_if_deleted()
out = []
for a in self._arrays:
out.append(Shard(a.device(), self.sharding, self.shape, a))
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out
@property
@ -514,7 +521,7 @@ class ArrayImpl(basearray.Array):
return self.addressable_shards
out = []
device_id_to_buffer = {a.device().id: a for a in self._arrays}
device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
for global_d in self.sharding.device_set:
if device_id_to_buffer.get(global_d.id, None) is not None:
array = device_id_to_buffer[global_d.id]
@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
if buf.device() == device:
if buf.devices() == {device}:
bufs.append(buf)
break
else:

View File

@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray,
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.device() == d)]
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)

View File

@ -228,7 +228,7 @@ class JitTest(jtu.BufferDonationTestCase):
device = jax.devices()[-1]
x = jit(lambda x: x, device=device)(3.)
_check_instance(self, x)
self.assertEqual(x.device(), device)
self.assertEqual(x.devices(), {device})
@parameterized.named_parameters(
('jit', jax.jit),
@ -239,42 +239,44 @@ class JitTest(jtu.BufferDonationTestCase):
if jax.device_count() == 1:
raise unittest.SkipTest("Test requires multiple devices")
system_default_device = jnp.add(1, 1).device()
system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
system_default_device = list(system_default_devices)[0]
test_device = jax.devices()[-1]
self.assertNotEqual(system_default_device, test_device)
f = module(lambda x: x + 1)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(f(1).devices(), system_default_devices)
with jax.default_device(test_device):
self.assertEqual(jnp.add(1, 1).device(), test_device)
self.assertEqual(f(1).device(), test_device)
self.assertEqual(jnp.add(1, 1).devices(), {test_device})
self.assertEqual(f(1).devices(), {test_device})
self.assertEqual(jnp.add(1, 1).device(), system_default_device)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)
with jax.default_device(test_device):
# Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual(
module(f, device=system_default_device)(1).device(),
system_default_device)
module(f, device=system_default_device)(1).devices(),
system_default_devices)
out = module(f, backend="cpu")(1)
self.assertEqual(out.device().platform, "cpu")
self.assertEqual(next(iter(out.devices())).platform, "cpu")
# Sticky input device overrides default_device
sticky = jax.device_put(1, system_default_device)
self.assertEqual(jnp.add(sticky, 1).device(), system_default_device)
self.assertEqual(f(sticky).device(), system_default_device)
self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).devices(), system_default_devices)
# Test nested default_devices
with jax.default_device(system_default_device):
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(f(1).device(), test_device)
self.assertEqual(f(1).devices(), system_default_devices)
self.assertEqual(f(1).devices(), {test_device})
# Test a few more non-default_device calls for good luck
self.assertEqual(jnp.add(1, 1).device(), system_default_device)
self.assertEqual(f(sticky).device(), system_default_device)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)
# TODO(skye): make this work!
def test_jit_default_platform(self):
@ -815,8 +817,8 @@ class JitTest(jtu.BufferDonationTestCase):
result = jitted_f(1.)
result_cpu = jitted_f_cpu(1.)
self.assertEqual(result.device().platform, jtu.device_under_test())
self.assertEqual(result_cpu.device().platform, "cpu")
self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")
@parameterized.named_parameters(
('jit', jax.jit),
@ -1697,7 +1699,7 @@ class APITest(jtu.JaxTestCase):
u = jax.device_put(y, jax.devices()[0])
self.assertArraysAllClose(u, y)
self.assertEqual(u.device(), jax.devices()[0])
self.assertEqual(u.devices(), {jax.devices()[0]})
def test_device_put_sharding_tree(self):
if jax.device_count() < 2:
@ -1830,10 +1832,10 @@ class APITest(jtu.JaxTestCase):
d1, d2 = jax.local_devices()[:2]
data = self.rng().randn(*shape).astype(np.float32)
x = api.device_put(data, device=d1)
self.assertEqual(x.device(), d1)
self.assertEqual(x.devices(), {d1})
y = api.device_put(x, device=d2)
self.assertEqual(y.device(), d2)
self.assertEqual(y.devices(), {d2})
np.testing.assert_array_equal(data, np.array(y))
# Make sure these don't crash
@ -1848,11 +1850,11 @@ class APITest(jtu.JaxTestCase):
np_arr = np.array([1,2,3])
scalar = 1
device_arr = jnp.array([1,2,3])
assert device_arr.device() is default_device
assert device_arr.devices() == {default_device}
for val in [np_arr, device_arr, scalar]:
x = api.device_put(val, device=cpu_device)
self.assertEqual(x.device(), cpu_device)
self.assertEqual(x.devices(), {cpu_device})
@jax.default_matmul_precision("float32")
def test_jacobian(self):
@ -3852,21 +3854,22 @@ class APITest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu")
def test_default_device(self):
system_default_device = jnp.zeros(2).device()
system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
test_device = jax.devices("cpu")[-1]
# Sanity check creating array using system default device
self.assertEqual(jnp.ones(1).device(), system_default_device)
self.assertEqual(jnp.ones(1).devices(), system_default_devices)
# Create array with default_device set
with jax.default_device(test_device):
# Hits cached primitive path
self.assertEqual(jnp.ones(1).device(), test_device)
self.assertEqual(jnp.ones(1).devices(), {test_device})
# Uncached
self.assertEqual(jnp.zeros((1, 2)).device(), test_device)
self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})
# Test that we can reset to system default device
self.assertEqual(jnp.ones(1).device(), system_default_device)
self.assertEqual(jnp.ones(1).devices(), system_default_devices)
def test_dunder_jax_array(self):
# https://github.com/google/jax/pull/4725

View File

@ -77,7 +77,7 @@ class DLPackTest(jtu.JaxTestCase):
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.device(), device)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)
self.assertRaisesRegex(RuntimeError,
@ -97,11 +97,11 @@ class DLPackTest(jtu.JaxTestCase):
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
y = jax.dlpack.from_dlpack(x)
self.assertEqual(y.device(), device)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)
# Test we can create multiple arrays
z = jax.dlpack.from_dlpack(x)
self.assertEqual(z.device(), device)
self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)

View File

@ -424,7 +424,7 @@ class JaxArrayTest(jtu.JaxTestCase):
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y],
self.assertArraysEqual([list(a.devices())[0] for a in y],
y.sharding._device_assignment,
allow_object_dtype=True)
@ -550,7 +550,7 @@ class JaxArrayTest(jtu.JaxTestCase):
for i, j in zip(arr, iter(input_data)):
self.assertArraysEqual(i, j)
self.assertEqual(i.device(), single_dev[0])
self.assertEqual(i.devices(), {single_dev[0]})
def test_array_shards_committed(self):
if jax.device_count() < 2:

View File

@ -61,12 +61,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
def assert_committed_to_device(self, data, device):
"""Asserts that the data is committed to the device."""
self.assertTrue(data._committed)
self.assertEqual(data.device(), device)
self.assertEqual(data.devices(), {device})
def assert_uncommitted_to_device(self, data, device):
"""Asserts that the data is on the device but not committed to it."""
self.assertFalse(data._committed)
self.assertEqual(data.device(), device)
self.assertEqual(data.devices(), {device})
def test_computation_follows_data(self):
if jax.device_count() < 5:

View File

@ -48,7 +48,7 @@ class MultiBackendTest(jtu.JaxTestCase):
z = fun(x, y)
self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = backend if backend else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform)
self.assertEqual(list(z.devices())[0].platform, correct_platform)
@jtu.sample_product(
ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
@ -72,7 +72,7 @@ class MultiBackendTest(jtu.JaxTestCase):
z = fun(x, y)
self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = outer if outer else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform)
self.assertEqual(list(z.devices())[0].platform, correct_platform)
@jtu.sample_product(
ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
@ -116,8 +116,8 @@ class MultiBackendTest(jtu.JaxTestCase):
y = npr.uniform(size=(10,10))
z = fun(x, y)
w = jnp.sin(z)
self.assertEqual(z.device().platform, backend)
self.assertEqual(w.device().platform, backend)
self.assertEqual(list(z.devices())[0].platform, backend)
self.assertEqual(list(w.devices())[0].platform, backend)
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
def testJitCpu(self):
@ -131,18 +131,18 @@ class MultiBackendTest(jtu.JaxTestCase):
b = x + jnp.ones_like(x)
c = x + jnp.eye(2)
self.assertEqual(a.device(), jax.devices('cpu')[0])
self.assertEqual(b.device(), jax.devices('cpu')[0])
self.assertEqual(c.device(), jax.devices('cpu')[0])
self.assertEqual(a.devices(), {jax.devices('cpu')[0]})
self.assertEqual(b.devices(), {jax.devices('cpu')[0]})
self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
def test_closed_over_values_device_placement(self):
# see https://github.com/google/jax/issues/1431
def f(): return jnp.add(3., 4.)
self.assertNotEqual(jax.jit(f)().device(),
jax.devices('cpu')[0])
self.assertEqual(jax.jit(f, backend='cpu')().device(),
jax.devices('cpu')[0])
self.assertNotEqual(jax.jit(f)().devices(),
{jax.devices('cpu')[0]})
self.assertEqual(jax.jit(f, backend='cpu')().devices(),
{jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_jit_on_nondefault_backend(self):
@ -154,22 +154,22 @@ class MultiBackendTest(jtu.JaxTestCase):
self.assertNotEqual(default_dev.platform, "cpu")
data_on_cpu = jax.device_put(1, device=cpus[0])
self.assertEqual(data_on_cpu.device(), cpus[0])
self.assertEqual(data_on_cpu.devices(), {cpus[0]})
def my_sin(x): return jnp.sin(x)
# jit without any device spec follows the data
result1 = jax.jit(my_sin)(2)
self.assertEqual(result1.device(), default_dev)
self.assertEqual(result1.devices(), {default_dev})
result2 = jax.jit(my_sin)(data_on_cpu)
self.assertEqual(result2.device(), cpus[0])
self.assertEqual(result2.devices(), {cpus[0]})
# jit with `device` spec places the data on the specified device
result3 = jax.jit(my_sin, device=cpus[0])(2)
self.assertEqual(result3.device(), cpus[0])
self.assertEqual(result3.devices(), {cpus[0]})
# jit with `backend` spec places the data on the specified backend
result4 = jax.jit(my_sin, backend="cpu")(2)
self.assertEqual(result4.device(), cpus[0])
self.assertEqual(result4.devices(), {cpus[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_indexing(self):
@ -178,7 +178,7 @@ class MultiBackendTest(jtu.JaxTestCase):
x = jax.device_put(np.ones(2), cpus[0])
y = x[0]
self.assertEqual(y.device(), cpus[0])
self.assertEqual(y.devices(), {cpus[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_sum(self):
@ -187,7 +187,7 @@ class MultiBackendTest(jtu.JaxTestCase):
x = jax.device_put(np.ones(2), cpus[0])
y = x.sum()
self.assertEqual(y.device(), cpus[0])
self.assertEqual(y.devices(), {cpus[0]})
if __name__ == "__main__":

View File

@ -98,7 +98,7 @@
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n",
"device = list(arr.devices())[0]\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
]

View File

@ -1,24 +1,10 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Colab GPU Test",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@ -27,8 +13,8 @@
{
"cell_type": "markdown",
"metadata": {
"id": "WkadOyTDCAWD",
"colab_type": "text"
"colab_type": "text",
"id": "WkadOyTDCAWD"
},
"source": [
"# JAX Colab GPU Test\n",
@ -38,15 +24,27 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "_tKNrbqqBHwu",
"colab_type": "code",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"colab_type": "code",
"id": "_tKNrbqqBHwu",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu-t4-s-kbefivsjoreh\n",
"0.1.64\n",
"0.1.45\n"
]
}
],
"source": [
"import jax\n",
"import jaxlib\n",
@ -54,25 +52,13 @@
"!cat /var/colab/hostname\n",
"print(jax.__version__)\n",
"print(jaxlib.__version__)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"gpu-t4-s-kbefivsjoreh\n",
"0.1.64\n",
"0.1.45\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oqEG21rADO1F",
"colab_type": "text"
"colab_type": "text",
"id": "oqEG21rADO1F"
},
"source": [
"## Confirm Device"
@ -80,39 +66,39 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX device type: gpu:0\n"
]
}
],
"source": [
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n",
"device = list(arr.devices())[0]\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"JAX device type: gpu:0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0FUY9yUC4k1",
"colab_type": "text"
"colab_type": "text",
"id": "z0FUY9yUC4k1"
},
"source": [
"## Matrix Multiplication"
@ -120,15 +106,25 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0216676\n"
]
}
],
"source": [
"import jax\n",
"import numpy as np\n",
@ -138,23 +134,13 @@
"x = jax.random.normal(key, (3000, 3000))\n",
"result = jax.numpy.dot(x, x.T).mean()\n",
"print(result)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"1.0216676\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0zTA2Q19DW4G",
"colab_type": "text"
"colab_type": "text",
"id": "0zTA2Q19DW4G"
},
"source": [
"## Linear Algebra"
@ -162,15 +148,26 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "uW9j84_UDYof",
"colab_type": "code",
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"colab_type": "code",
"id": "uW9j84_UDYof",
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
" 2.866489 1.8229384 1.5478926]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as rand\n",
@ -184,24 +181,13 @@
"assert u.shape == (N, N)\n",
"assert vt.shape == (M, M)\n",
"print(s)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
" 2.866489 1.8229384 1.5478926]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCyKUn4-DCXn",
"colab_type": "text"
"colab_type": "text",
"id": "jCyKUn4-DCXn"
},
"source": [
"## XLA Compilation"
@ -209,15 +195,26 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
" 0.13093245]\n"
]
}
],
"source": [
"@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n",
@ -225,18 +222,21 @@
"x = jax.random.normal(key, (5000,))\n",
"result = selu(x).block_until_ready()\n",
"print(result)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
" 0.13093245]\n"
],
"name": "stdout"
}
]
}
]
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "JAX Colab GPU Test",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -2539,7 +2539,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertFalse(a._committed)
out = f(a, a)
self.assertFalse(out._committed)
self.assertEqual(out.device(), jax.devices()[0])
self.assertEqual(out.devices(), {jax.devices()[0]})
self.assertArraysEqual(out, a * 2)
with jax.default_device(jax.devices()[1]):
@ -2547,7 +2547,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertFalse(b._committed)
out2 = f(b, b)
self.assertFalse(out2._committed)
self.assertEqual(out2.device(), jax.devices()[1])
self.assertEqual(out2.devices(), {jax.devices()[1]})
self.assertArraysEqual(out2, b * 2)
def test_pjit_with_static_argnames(self):
@ -2590,7 +2590,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
if jax.device_count() <= 1:
self.skipTest('Test requires more >1 device.')
system_default_device = jnp.add(1, 1).device()
system_default_device = list(jnp.add(1, 1).devices())[0]
test_device = jax.devices()[-1]
f = pjit(lambda x: x + 1)
@ -2733,7 +2733,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x @ x.T
def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device)
self.assertEqual(out.devices(), {expected_device})
self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out @ expected_out.T)
@ -2776,14 +2776,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_device = jax.devices()[2]
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
self.assertEqual(final_out.device(), expected_device)
self.assertEqual(final_out.devices(), {expected_device})
self.assertLen(final_out.sharding.device_set, 1)
self.assertArraysEqual(final_out, inp * 6)
@jtu.run_on_devices("tpu")
def test_pjit_with_backend_arg(self):
def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device)
self.assertEqual(out.devices(), {expected_device})
self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out)
@ -3403,7 +3403,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
y = jax.device_put(x, jax.devices()[1])
out2 = jax.jit(lambda x: x)(y)
self.assertIsInstance(out2.sharding, SingleDeviceSharding)
self.assertEqual(out2.device(), jax.devices()[1])
self.assertEqual(out2.devices(), {jax.devices()[1]})
out3 = jax.jit(lambda x: x * 2)(x)
self.assertIsInstance(out3.sharding, SingleDeviceSharding)
@ -3411,7 +3411,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out4 = jax.jit(lambda x: x * 3,
out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1])
self.assertEqual(out4.devices(), {jax.devices()[1]})
def test_none_out_sharding(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@ -3449,7 +3449,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
arr4 = jax.device_put(jnp.arange(8), jax.devices()[1])
out4 = jnp.copy(arr4)
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1])
self.assertEqual(out4.devices(), {jax.devices()[1]})
def test_get_indices_cache(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -3506,12 +3506,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Fill up the to_gspmd_sharding cache so that the next jit will miss it.
out = jax.jit(identity,
in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp)
self.assertEqual(out.device(), jax.devices()[0])
self.assertEqual(out.devices(), {jax.devices()[0]})
self.assertArraysEqual(out, np_inp)
out2 = jax.jit(identity, device=jax.devices()[0])(
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
self.assertEqual(out2.device(), jax.devices()[0])
self.assertEqual(out2.devices(), {jax.devices()[0]})
self.assertArraysEqual(out2, np_inp)
def test_jit_submhlo_cached(self):

View File

@ -147,12 +147,12 @@ class PythonPmapTest(jtu.JaxTestCase):
view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view)
self.assertEqual(buf.device(), view.device())
self.assertSetEqual(buf.devices(), view.devices())
self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer())
copy = jnp.array(buf, copy=True)
self.assertArraysEqual(sda[-1], copy)
self.assertEqual(buf.device(), copy.device())
self.assertSetEqual(buf.devices(), copy.devices())
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
def _getMeshShape(self, device_mesh_shape):
@ -869,7 +869,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# test that we can handle device movement on dispatch
bufs = y._arrays[::-1]
sharding = jax.sharding.PmapSharding(
[b.device() for b in bufs], y.sharding.sharding_spec)
[list(b.devices())[0] for b in bufs], y.sharding.sharding_spec)
y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs)
z = f(y)
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
@ -2769,7 +2769,7 @@ class ArrayTest(jtu.JaxTestCase):
self.assertEqual(s.replica_id, 0)
buffers = getattr(y, '_arrays')
self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, jnp.stack(x))
def test_device_put_sharded_pytree(self):
@ -2781,12 +2781,12 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y1, array.ArrayImpl)
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
y1_buffers = getattr(y1, '_arrays')
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
self.assertIsInstance(y2, array.ArrayImpl)
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
y2_buffers = getattr(y2, '_arrays')
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
def test_device_put_replicated(self):
devices = jax.local_devices()
@ -2796,7 +2796,7 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y, array.ArrayImpl)
buffers = getattr(y, '_arrays')
self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, np.stack([x for _ in devices]))
def test_device_put_replicated_pytree(self):
@ -2809,13 +2809,13 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y1, array.ArrayImpl)
y1_buffers = getattr(y1, '_arrays')
self.assertEqual(len(y1_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
self.assertIsInstance(y2, array.ArrayImpl)
y2_buffers = getattr(y2, '_arrays')
self.assertEqual(len(y2_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
def test_repr(self):
@ -3127,8 +3127,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.skipTest('Test requires >= 2 devices.')
def amap(f, xs):
ys = [f(jax.device_put(x, x.device())) for x in xs]
return jax.device_put_sharded(ys, [y.device() for y in ys])
ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
# leading axis is batch dim (i.e. mapped/parallel dim), of size 2
x = jnp.array([[1., 0., 0.],

View File

@ -1026,7 +1026,8 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
self.assertEqual(key.device(), key._base_array.device())
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
self.assertEqual(key.device(), key._base_array.device())
self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)