mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate the device() method of JAX arrays
This commit is contained in:
parent
4de07b3f62
commit
97beb01c43
@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
but the previous outputs can be recovered this way:
|
||||
* `arr.device_buffer` becomes `arr.addressable_data(0)`
|
||||
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
|
||||
|
||||
* The `device()` method of JAX arrays deprecated. Depending on the context, it may
|
||||
be replaced with one of the following:
|
||||
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
|
||||
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
|
||||
|
||||
## jaxlib 0.4.21
|
||||
|
||||
|
@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
|
||||
platforms are available in priority order).
|
||||
|
||||
>>> from jax import numpy as jnp
|
||||
>>> print(jnp.ones(3).device()) # doctest: +SKIP
|
||||
cuda:0
|
||||
>>> print(jnp.ones(3).devices()) # doctest: +SKIP
|
||||
{CudaDevice(id=0)}
|
||||
|
||||
Computations involving uncommitted data are performed on the default
|
||||
device and the results are uncommitted on the default device.
|
||||
@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the
|
||||
|
||||
>>> import jax
|
||||
>>> from jax import device_put
|
||||
>>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP
|
||||
cuda:2
|
||||
>>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP
|
||||
>>> print(arr.devices()) # doctest: +SKIP
|
||||
{CudaDevice(id=2)}
|
||||
|
||||
Computations involving some committed inputs will happen on the
|
||||
committed device and the result will be committed on the
|
||||
|
@ -51,6 +51,11 @@ Device = xc.Device
|
||||
Index = tuple[slice, ...]
|
||||
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
def _get_device(a: ArrayImpl) -> Device:
|
||||
assert len(a.devices()) == 1
|
||||
return next(iter(a.devices()))
|
||||
|
||||
|
||||
class Shard:
|
||||
"""A single data shard of an Array.
|
||||
|
||||
@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
|
||||
di_map = _cached_index_calc(s, shape)
|
||||
copy_plan = []
|
||||
for a in arrays:
|
||||
ind = di_map.get(a.device(), None)
|
||||
ind = di_map.get(_get_device(a), None)
|
||||
if ind is not None:
|
||||
copy_plan.append((ind, a))
|
||||
return copy_plan
|
||||
@ -183,7 +188,7 @@ class ArrayImpl(basearray.Array):
|
||||
"Input buffers to `Array` must have matching dtypes. "
|
||||
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")
|
||||
|
||||
device_id_to_buffer = {db.device().id: db for db in self._arrays}
|
||||
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
|
||||
|
||||
addressable_dev = self.sharding.addressable_devices
|
||||
if len(self._arrays) != len(addressable_dev):
|
||||
@ -324,7 +329,7 @@ class ArrayImpl(basearray.Array):
|
||||
if arr_idx is not None:
|
||||
a = self._arrays[arr_idx]
|
||||
return ArrayImpl(
|
||||
a.aval, SingleDeviceSharding(a.device()), [a], committed=False,
|
||||
a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
|
||||
_skip_checks=True)
|
||||
return lax_numpy._rewriting_take(self, idx)
|
||||
else:
|
||||
@ -400,7 +405,7 @@ class ArrayImpl(basearray.Array):
|
||||
return DLDeviceType.kDLCPU, 0
|
||||
|
||||
elif self.platform() == "gpu":
|
||||
platform_version = self.device().client.platform_version
|
||||
platform_version = _get_device(self).client.platform_version
|
||||
if "cuda" in platform_version:
|
||||
dl_device_type = DLDeviceType.kDLCUDA
|
||||
elif "rocm" in platform_version:
|
||||
@ -409,7 +414,7 @@ class ArrayImpl(basearray.Array):
|
||||
raise ValueError("Unknown GPU platform for __dlpack__: "
|
||||
f"{platform_version}")
|
||||
|
||||
local_hardware_id = self.device().local_hardware_id
|
||||
local_hardware_id = _get_device(self).local_hardware_id
|
||||
if local_hardware_id is None:
|
||||
raise ValueError("Couldn't get local_hardware_id for __dlpack__")
|
||||
|
||||
@ -451,6 +456,8 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
# TODO(yashkatariya): Remove this method when everyone is using devices().
|
||||
def device(self) -> Device:
|
||||
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
device_set = self.sharding.device_set
|
||||
if len(device_set) == 1:
|
||||
@ -499,7 +506,7 @@ class ArrayImpl(basearray.Array):
|
||||
self._check_if_deleted()
|
||||
out = []
|
||||
for a in self._arrays:
|
||||
out.append(Shard(a.device(), self.sharding, self.shape, a))
|
||||
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
|
||||
return out
|
||||
|
||||
@property
|
||||
@ -514,7 +521,7 @@ class ArrayImpl(basearray.Array):
|
||||
return self.addressable_shards
|
||||
|
||||
out = []
|
||||
device_id_to_buffer = {a.device().id: a for a in self._arrays}
|
||||
device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
|
||||
for global_d in self.sharding.device_set:
|
||||
if device_id_to_buffer.get(global_d.id, None) is not None:
|
||||
array = device_id_to_buffer[global_d.id]
|
||||
@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
|
||||
# Try to find a candidate buffer already on the correct device,
|
||||
# otherwise copy one of them.
|
||||
for buf in candidates_list:
|
||||
if buf.device() == device:
|
||||
if buf.devices() == {device}:
|
||||
bufs.append(buf)
|
||||
break
|
||||
else:
|
||||
|
@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray,
|
||||
bufs = [x for x, d in safe_zip(xs, devices)
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.device() == d)]
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
|
@ -228,7 +228,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
device = jax.devices()[-1]
|
||||
x = jit(lambda x: x, device=device)(3.)
|
||||
_check_instance(self, x)
|
||||
self.assertEqual(x.device(), device)
|
||||
self.assertEqual(x.devices(), {device})
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('jit', jax.jit),
|
||||
@ -239,42 +239,44 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
if jax.device_count() == 1:
|
||||
raise unittest.SkipTest("Test requires multiple devices")
|
||||
|
||||
system_default_device = jnp.add(1, 1).device()
|
||||
system_default_devices = jnp.add(1, 1).devices()
|
||||
self.assertLen(system_default_devices, 1)
|
||||
system_default_device = list(system_default_devices)[0]
|
||||
test_device = jax.devices()[-1]
|
||||
self.assertNotEqual(system_default_device, test_device)
|
||||
|
||||
f = module(lambda x: x + 1)
|
||||
self.assertEqual(f(1).device(), system_default_device)
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
|
||||
with jax.default_device(test_device):
|
||||
self.assertEqual(jnp.add(1, 1).device(), test_device)
|
||||
self.assertEqual(f(1).device(), test_device)
|
||||
self.assertEqual(jnp.add(1, 1).devices(), {test_device})
|
||||
self.assertEqual(f(1).devices(), {test_device})
|
||||
|
||||
self.assertEqual(jnp.add(1, 1).device(), system_default_device)
|
||||
self.assertEqual(f(1).device(), system_default_device)
|
||||
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
|
||||
with jax.default_device(test_device):
|
||||
# Explicit `device` or `backend` argument to jit overrides default_device
|
||||
self.assertEqual(
|
||||
module(f, device=system_default_device)(1).device(),
|
||||
system_default_device)
|
||||
module(f, device=system_default_device)(1).devices(),
|
||||
system_default_devices)
|
||||
out = module(f, backend="cpu")(1)
|
||||
self.assertEqual(out.device().platform, "cpu")
|
||||
self.assertEqual(next(iter(out.devices())).platform, "cpu")
|
||||
|
||||
# Sticky input device overrides default_device
|
||||
sticky = jax.device_put(1, system_default_device)
|
||||
self.assertEqual(jnp.add(sticky, 1).device(), system_default_device)
|
||||
self.assertEqual(f(sticky).device(), system_default_device)
|
||||
self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
|
||||
self.assertEqual(f(sticky).devices(), system_default_devices)
|
||||
|
||||
# Test nested default_devices
|
||||
with jax.default_device(system_default_device):
|
||||
self.assertEqual(f(1).device(), system_default_device)
|
||||
self.assertEqual(f(1).device(), test_device)
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
self.assertEqual(f(1).devices(), {test_device})
|
||||
|
||||
# Test a few more non-default_device calls for good luck
|
||||
self.assertEqual(jnp.add(1, 1).device(), system_default_device)
|
||||
self.assertEqual(f(sticky).device(), system_default_device)
|
||||
self.assertEqual(f(1).device(), system_default_device)
|
||||
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
|
||||
self.assertEqual(f(sticky).devices(), system_default_devices)
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
|
||||
# TODO(skye): make this work!
|
||||
def test_jit_default_platform(self):
|
||||
@ -815,8 +817,8 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
result = jitted_f(1.)
|
||||
result_cpu = jitted_f_cpu(1.)
|
||||
self.assertEqual(result.device().platform, jtu.device_under_test())
|
||||
self.assertEqual(result_cpu.device().platform, "cpu")
|
||||
self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
|
||||
self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('jit', jax.jit),
|
||||
@ -1697,7 +1699,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
u = jax.device_put(y, jax.devices()[0])
|
||||
self.assertArraysAllClose(u, y)
|
||||
self.assertEqual(u.device(), jax.devices()[0])
|
||||
self.assertEqual(u.devices(), {jax.devices()[0]})
|
||||
|
||||
def test_device_put_sharding_tree(self):
|
||||
if jax.device_count() < 2:
|
||||
@ -1830,10 +1832,10 @@ class APITest(jtu.JaxTestCase):
|
||||
d1, d2 = jax.local_devices()[:2]
|
||||
data = self.rng().randn(*shape).astype(np.float32)
|
||||
x = api.device_put(data, device=d1)
|
||||
self.assertEqual(x.device(), d1)
|
||||
self.assertEqual(x.devices(), {d1})
|
||||
|
||||
y = api.device_put(x, device=d2)
|
||||
self.assertEqual(y.device(), d2)
|
||||
self.assertEqual(y.devices(), {d2})
|
||||
|
||||
np.testing.assert_array_equal(data, np.array(y))
|
||||
# Make sure these don't crash
|
||||
@ -1848,11 +1850,11 @@ class APITest(jtu.JaxTestCase):
|
||||
np_arr = np.array([1,2,3])
|
||||
scalar = 1
|
||||
device_arr = jnp.array([1,2,3])
|
||||
assert device_arr.device() is default_device
|
||||
assert device_arr.devices() == {default_device}
|
||||
|
||||
for val in [np_arr, device_arr, scalar]:
|
||||
x = api.device_put(val, device=cpu_device)
|
||||
self.assertEqual(x.device(), cpu_device)
|
||||
self.assertEqual(x.devices(), {cpu_device})
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_jacobian(self):
|
||||
@ -3852,21 +3854,22 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_default_device(self):
|
||||
system_default_device = jnp.zeros(2).device()
|
||||
system_default_devices = jnp.add(1, 1).devices()
|
||||
self.assertLen(system_default_devices, 1)
|
||||
test_device = jax.devices("cpu")[-1]
|
||||
|
||||
# Sanity check creating array using system default device
|
||||
self.assertEqual(jnp.ones(1).device(), system_default_device)
|
||||
self.assertEqual(jnp.ones(1).devices(), system_default_devices)
|
||||
|
||||
# Create array with default_device set
|
||||
with jax.default_device(test_device):
|
||||
# Hits cached primitive path
|
||||
self.assertEqual(jnp.ones(1).device(), test_device)
|
||||
self.assertEqual(jnp.ones(1).devices(), {test_device})
|
||||
# Uncached
|
||||
self.assertEqual(jnp.zeros((1, 2)).device(), test_device)
|
||||
self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})
|
||||
|
||||
# Test that we can reset to system default device
|
||||
self.assertEqual(jnp.ones(1).device(), system_default_device)
|
||||
self.assertEqual(jnp.ones(1).devices(), system_default_devices)
|
||||
|
||||
def test_dunder_jax_array(self):
|
||||
# https://github.com/google/jax/pull/4725
|
||||
|
@ -77,7 +77,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
x = jax.device_put(np, device)
|
||||
dlpack = jax.dlpack.to_dlpack(x)
|
||||
y = jax.dlpack.from_dlpack(dlpack)
|
||||
self.assertEqual(y.device(), device)
|
||||
self.assertEqual(y.devices(), {device})
|
||||
self.assertAllClose(np.astype(x.dtype), y)
|
||||
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
@ -97,11 +97,11 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
device = jax.devices("gpu" if gpu else "cpu")[0]
|
||||
x = jax.device_put(np, device)
|
||||
y = jax.dlpack.from_dlpack(x)
|
||||
self.assertEqual(y.device(), device)
|
||||
self.assertEqual(y.devices(), {device})
|
||||
self.assertAllClose(np.astype(x.dtype), y)
|
||||
# Test we can create multiple arrays
|
||||
z = jax.dlpack.from_dlpack(x)
|
||||
self.assertEqual(z.device(), device)
|
||||
self.assertEqual(z.devices(), {device})
|
||||
self.assertAllClose(np.astype(x.dtype), z)
|
||||
|
||||
|
||||
|
@ -424,7 +424,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
|
||||
y = jax.pmap(jnp.sin)(x)
|
||||
self.assertArraysEqual([a.device() for a in y],
|
||||
self.assertArraysEqual([list(a.devices())[0] for a in y],
|
||||
y.sharding._device_assignment,
|
||||
allow_object_dtype=True)
|
||||
|
||||
@ -550,7 +550,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
for i, j in zip(arr, iter(input_data)):
|
||||
self.assertArraysEqual(i, j)
|
||||
self.assertEqual(i.device(), single_dev[0])
|
||||
self.assertEqual(i.devices(), {single_dev[0]})
|
||||
|
||||
def test_array_shards_committed(self):
|
||||
if jax.device_count() < 2:
|
||||
|
@ -61,12 +61,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
def assert_committed_to_device(self, data, device):
|
||||
"""Asserts that the data is committed to the device."""
|
||||
self.assertTrue(data._committed)
|
||||
self.assertEqual(data.device(), device)
|
||||
self.assertEqual(data.devices(), {device})
|
||||
|
||||
def assert_uncommitted_to_device(self, data, device):
|
||||
"""Asserts that the data is on the device but not committed to it."""
|
||||
self.assertFalse(data._committed)
|
||||
self.assertEqual(data.device(), device)
|
||||
self.assertEqual(data.devices(), {device})
|
||||
|
||||
def test_computation_follows_data(self):
|
||||
if jax.device_count() < 5:
|
||||
|
@ -48,7 +48,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
z = fun(x, y)
|
||||
self.assertAllClose(z, z_host, rtol=1e-2)
|
||||
correct_platform = backend if backend else jtu.device_under_test()
|
||||
self.assertEqual(z.device().platform, correct_platform)
|
||||
self.assertEqual(list(z.devices())[0].platform, correct_platform)
|
||||
|
||||
@jtu.sample_product(
|
||||
ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
|
||||
@ -72,7 +72,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
z = fun(x, y)
|
||||
self.assertAllClose(z, z_host, rtol=1e-2)
|
||||
correct_platform = outer if outer else jtu.device_under_test()
|
||||
self.assertEqual(z.device().platform, correct_platform)
|
||||
self.assertEqual(list(z.devices())[0].platform, correct_platform)
|
||||
|
||||
@jtu.sample_product(
|
||||
ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
|
||||
@ -116,8 +116,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
y = npr.uniform(size=(10,10))
|
||||
z = fun(x, y)
|
||||
w = jnp.sin(z)
|
||||
self.assertEqual(z.device().platform, backend)
|
||||
self.assertEqual(w.device().platform, backend)
|
||||
self.assertEqual(list(z.devices())[0].platform, backend)
|
||||
self.assertEqual(list(w.devices())[0].platform, backend)
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
|
||||
def testJitCpu(self):
|
||||
@ -131,18 +131,18 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
b = x + jnp.ones_like(x)
|
||||
c = x + jnp.eye(2)
|
||||
|
||||
self.assertEqual(a.device(), jax.devices('cpu')[0])
|
||||
self.assertEqual(b.device(), jax.devices('cpu')[0])
|
||||
self.assertEqual(c.device(), jax.devices('cpu')[0])
|
||||
self.assertEqual(a.devices(), {jax.devices('cpu')[0]})
|
||||
self.assertEqual(b.devices(), {jax.devices('cpu')[0]})
|
||||
self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
|
||||
def test_closed_over_values_device_placement(self):
|
||||
# see https://github.com/google/jax/issues/1431
|
||||
def f(): return jnp.add(3., 4.)
|
||||
self.assertNotEqual(jax.jit(f)().device(),
|
||||
jax.devices('cpu')[0])
|
||||
self.assertEqual(jax.jit(f, backend='cpu')().device(),
|
||||
jax.devices('cpu')[0])
|
||||
self.assertNotEqual(jax.jit(f)().devices(),
|
||||
{jax.devices('cpu')[0]})
|
||||
self.assertEqual(jax.jit(f, backend='cpu')().devices(),
|
||||
{jax.devices('cpu')[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
|
||||
def test_jit_on_nondefault_backend(self):
|
||||
@ -154,22 +154,22 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(default_dev.platform, "cpu")
|
||||
|
||||
data_on_cpu = jax.device_put(1, device=cpus[0])
|
||||
self.assertEqual(data_on_cpu.device(), cpus[0])
|
||||
self.assertEqual(data_on_cpu.devices(), {cpus[0]})
|
||||
|
||||
def my_sin(x): return jnp.sin(x)
|
||||
# jit without any device spec follows the data
|
||||
result1 = jax.jit(my_sin)(2)
|
||||
self.assertEqual(result1.device(), default_dev)
|
||||
self.assertEqual(result1.devices(), {default_dev})
|
||||
result2 = jax.jit(my_sin)(data_on_cpu)
|
||||
self.assertEqual(result2.device(), cpus[0])
|
||||
self.assertEqual(result2.devices(), {cpus[0]})
|
||||
|
||||
# jit with `device` spec places the data on the specified device
|
||||
result3 = jax.jit(my_sin, device=cpus[0])(2)
|
||||
self.assertEqual(result3.device(), cpus[0])
|
||||
self.assertEqual(result3.devices(), {cpus[0]})
|
||||
|
||||
# jit with `backend` spec places the data on the specified backend
|
||||
result4 = jax.jit(my_sin, backend="cpu")(2)
|
||||
self.assertEqual(result4.device(), cpus[0])
|
||||
self.assertEqual(result4.devices(), {cpus[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
|
||||
def test_indexing(self):
|
||||
@ -178,7 +178,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
|
||||
x = jax.device_put(np.ones(2), cpus[0])
|
||||
y = x[0]
|
||||
self.assertEqual(y.device(), cpus[0])
|
||||
self.assertEqual(y.devices(), {cpus[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
|
||||
def test_sum(self):
|
||||
@ -187,7 +187,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
|
||||
x = jax.device_put(np.ones(2), cpus[0])
|
||||
y = x.sum()
|
||||
self.assertEqual(y.device(), cpus[0])
|
||||
self.assertEqual(y.devices(), {cpus[0]})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -98,7 +98,7 @@
|
||||
"import jax\n",
|
||||
"key = jax.random.PRNGKey(1701)\n",
|
||||
"arr = jax.random.normal(key, (1000,))\n",
|
||||
"device = arr.device()\n",
|
||||
"device = list(arr.devices())[0]\n",
|
||||
"print(f\"JAX device type: {device}\")\n",
|
||||
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
|
||||
]
|
||||
|
@ -1,24 +1,10 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "JAX Colab GPU Test",
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
@ -27,8 +13,8 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WkadOyTDCAWD",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "WkadOyTDCAWD"
|
||||
},
|
||||
"source": [
|
||||
"# JAX Colab GPU Test\n",
|
||||
@ -38,15 +24,27 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "_tKNrbqqBHwu",
|
||||
"colab_type": "code",
|
||||
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 68
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "_tKNrbqqBHwu",
|
||||
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"gpu-t4-s-kbefivsjoreh\n",
|
||||
"0.1.64\n",
|
||||
"0.1.45\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"import jaxlib\n",
|
||||
@ -54,25 +52,13 @@
|
||||
"!cat /var/colab/hostname\n",
|
||||
"print(jax.__version__)\n",
|
||||
"print(jaxlib.__version__)"
|
||||
],
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"gpu-t4-s-kbefivsjoreh\n",
|
||||
"0.1.64\n",
|
||||
"0.1.45\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "oqEG21rADO1F",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "oqEG21rADO1F"
|
||||
},
|
||||
"source": [
|
||||
"## Confirm Device"
|
||||
@ -80,39 +66,39 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "8BwzMYhKGQj6",
|
||||
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 34
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "8BwzMYhKGQj6",
|
||||
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"JAX device type: gpu:0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"key = jax.random.PRNGKey(1701)\n",
|
||||
"arr = jax.random.normal(key, (1000,))\n",
|
||||
"device = arr.device()\n",
|
||||
"device = list(arr.devices())[0]\n",
|
||||
"print(f\"JAX device type: {device}\")\n",
|
||||
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"JAX device type: gpu:0\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "z0FUY9yUC4k1",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "z0FUY9yUC4k1"
|
||||
},
|
||||
"source": [
|
||||
"## Matrix Multiplication"
|
||||
@ -120,15 +106,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "eXn8GUl6CG5N",
|
||||
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 34
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "eXn8GUl6CG5N",
|
||||
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.0216676\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"import numpy as np\n",
|
||||
@ -138,23 +134,13 @@
|
||||
"x = jax.random.normal(key, (3000, 3000))\n",
|
||||
"result = jax.numpy.dot(x, x.T).mean()\n",
|
||||
"print(result)"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.0216676\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "0zTA2Q19DW4G",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "0zTA2Q19DW4G"
|
||||
},
|
||||
"source": [
|
||||
"## Linear Algebra"
|
||||
@ -162,15 +148,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "uW9j84_UDYof",
|
||||
"colab_type": "code",
|
||||
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 51
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "uW9j84_UDYof",
|
||||
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
|
||||
" 2.866489 1.8229384 1.5478926]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax.random as rand\n",
|
||||
@ -184,24 +181,13 @@
|
||||
"assert u.shape == (N, N)\n",
|
||||
"assert vt.shape == (M, M)\n",
|
||||
"print(s)"
|
||||
],
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
|
||||
" 2.866489 1.8229384 1.5478926]\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "jCyKUn4-DCXn",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "jCyKUn4-DCXn"
|
||||
},
|
||||
"source": [
|
||||
"## XLA Compilation"
|
||||
@ -209,15 +195,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "2GOn_HhDPuEn",
|
||||
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 51
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "2GOn_HhDPuEn",
|
||||
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
|
||||
" 0.13093245]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@jax.jit\n",
|
||||
"def selu(x, alpha=1.67, lmbda=1.05):\n",
|
||||
@ -225,18 +222,21 @@
|
||||
"x = jax.random.normal(key, (5000,))\n",
|
||||
"result = selu(x).block_until_ready()\n",
|
||||
"print(result)"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
|
||||
" 0.13093245]\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "JAX Colab GPU Test",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
@ -2539,7 +2539,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertFalse(a._committed)
|
||||
out = f(a, a)
|
||||
self.assertFalse(out._committed)
|
||||
self.assertEqual(out.device(), jax.devices()[0])
|
||||
self.assertEqual(out.devices(), {jax.devices()[0]})
|
||||
self.assertArraysEqual(out, a * 2)
|
||||
|
||||
with jax.default_device(jax.devices()[1]):
|
||||
@ -2547,7 +2547,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertFalse(b._committed)
|
||||
out2 = f(b, b)
|
||||
self.assertFalse(out2._committed)
|
||||
self.assertEqual(out2.device(), jax.devices()[1])
|
||||
self.assertEqual(out2.devices(), {jax.devices()[1]})
|
||||
self.assertArraysEqual(out2, b * 2)
|
||||
|
||||
def test_pjit_with_static_argnames(self):
|
||||
@ -2590,7 +2590,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
if jax.device_count() <= 1:
|
||||
self.skipTest('Test requires more >1 device.')
|
||||
|
||||
system_default_device = jnp.add(1, 1).device()
|
||||
system_default_device = list(jnp.add(1, 1).devices())[0]
|
||||
test_device = jax.devices()[-1]
|
||||
|
||||
f = pjit(lambda x: x + 1)
|
||||
@ -2733,7 +2733,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return x @ x.T
|
||||
|
||||
def _check(out, expected_device, expected_out):
|
||||
self.assertEqual(out.device(), expected_device)
|
||||
self.assertEqual(out.devices(), {expected_device})
|
||||
self.assertLen(out.sharding.device_set, 1)
|
||||
self.assertArraysEqual(out, expected_out @ expected_out.T)
|
||||
|
||||
@ -2776,14 +2776,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
expected_device = jax.devices()[2]
|
||||
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
|
||||
|
||||
self.assertEqual(final_out.device(), expected_device)
|
||||
self.assertEqual(final_out.devices(), {expected_device})
|
||||
self.assertLen(final_out.sharding.device_set, 1)
|
||||
self.assertArraysEqual(final_out, inp * 6)
|
||||
|
||||
@jtu.run_on_devices("tpu")
|
||||
def test_pjit_with_backend_arg(self):
|
||||
def _check(out, expected_device, expected_out):
|
||||
self.assertEqual(out.device(), expected_device)
|
||||
self.assertEqual(out.devices(), {expected_device})
|
||||
self.assertLen(out.sharding.device_set, 1)
|
||||
self.assertArraysEqual(out, expected_out)
|
||||
|
||||
@ -3403,7 +3403,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
y = jax.device_put(x, jax.devices()[1])
|
||||
out2 = jax.jit(lambda x: x)(y)
|
||||
self.assertIsInstance(out2.sharding, SingleDeviceSharding)
|
||||
self.assertEqual(out2.device(), jax.devices()[1])
|
||||
self.assertEqual(out2.devices(), {jax.devices()[1]})
|
||||
|
||||
out3 = jax.jit(lambda x: x * 2)(x)
|
||||
self.assertIsInstance(out3.sharding, SingleDeviceSharding)
|
||||
@ -3411,7 +3411,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out4 = jax.jit(lambda x: x * 3,
|
||||
out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
|
||||
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
|
||||
self.assertEqual(out4.device(), jax.devices()[1])
|
||||
self.assertEqual(out4.devices(), {jax.devices()[1]})
|
||||
|
||||
def test_none_out_sharding(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
@ -3449,7 +3449,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
arr4 = jax.device_put(jnp.arange(8), jax.devices()[1])
|
||||
out4 = jnp.copy(arr4)
|
||||
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
|
||||
self.assertEqual(out4.device(), jax.devices()[1])
|
||||
self.assertEqual(out4.devices(), {jax.devices()[1]})
|
||||
|
||||
def test_get_indices_cache(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
@ -3506,12 +3506,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# Fill up the to_gspmd_sharding cache so that the next jit will miss it.
|
||||
out = jax.jit(identity,
|
||||
in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp)
|
||||
self.assertEqual(out.device(), jax.devices()[0])
|
||||
self.assertEqual(out.devices(), {jax.devices()[0]})
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
out2 = jax.jit(identity, device=jax.devices()[0])(
|
||||
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
|
||||
self.assertEqual(out2.device(), jax.devices()[0])
|
||||
self.assertEqual(out2.devices(), {jax.devices()[0]})
|
||||
self.assertArraysEqual(out2, np_inp)
|
||||
|
||||
def test_jit_submhlo_cached(self):
|
||||
|
@ -147,12 +147,12 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
view = jnp.array(buf, copy=False)
|
||||
self.assertArraysEqual(sda[-1], view)
|
||||
self.assertEqual(buf.device(), view.device())
|
||||
self.assertSetEqual(buf.devices(), view.devices())
|
||||
self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer())
|
||||
|
||||
copy = jnp.array(buf, copy=True)
|
||||
self.assertArraysEqual(sda[-1], copy)
|
||||
self.assertEqual(buf.device(), copy.device())
|
||||
self.assertSetEqual(buf.devices(), copy.devices())
|
||||
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
|
||||
|
||||
def _getMeshShape(self, device_mesh_shape):
|
||||
@ -869,7 +869,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
# test that we can handle device movement on dispatch
|
||||
bufs = y._arrays[::-1]
|
||||
sharding = jax.sharding.PmapSharding(
|
||||
[b.device() for b in bufs], y.sharding.sharding_spec)
|
||||
[list(b.devices())[0] for b in bufs], y.sharding.sharding_spec)
|
||||
y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs)
|
||||
z = f(y)
|
||||
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
|
||||
@ -2769,7 +2769,7 @@ class ArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(s.replica_id, 0)
|
||||
buffers = getattr(y, '_arrays')
|
||||
self.assertEqual(len(buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
|
||||
self.assertArraysEqual(y, jnp.stack(x))
|
||||
|
||||
def test_device_put_sharded_pytree(self):
|
||||
@ -2781,12 +2781,12 @@ class ArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(y1, array.ArrayImpl)
|
||||
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
|
||||
y1_buffers = getattr(y1, '_arrays')
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
|
||||
|
||||
self.assertIsInstance(y2, array.ArrayImpl)
|
||||
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
|
||||
y2_buffers = getattr(y2, '_arrays')
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
|
||||
|
||||
def test_device_put_replicated(self):
|
||||
devices = jax.local_devices()
|
||||
@ -2796,7 +2796,7 @@ class ArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(y, array.ArrayImpl)
|
||||
buffers = getattr(y, '_arrays')
|
||||
self.assertEqual(len(buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
|
||||
self.assertArraysEqual(y, np.stack([x for _ in devices]))
|
||||
|
||||
def test_device_put_replicated_pytree(self):
|
||||
@ -2809,13 +2809,13 @@ class ArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(y1, array.ArrayImpl)
|
||||
y1_buffers = getattr(y1, '_arrays')
|
||||
self.assertEqual(len(y1_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
|
||||
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
|
||||
|
||||
self.assertIsInstance(y2, array.ArrayImpl)
|
||||
y2_buffers = getattr(y2, '_arrays')
|
||||
self.assertEqual(len(y2_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
|
||||
self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
|
||||
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
|
||||
|
||||
def test_repr(self):
|
||||
@ -3127,8 +3127,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
self.skipTest('Test requires >= 2 devices.')
|
||||
|
||||
def amap(f, xs):
|
||||
ys = [f(jax.device_put(x, x.device())) for x in xs]
|
||||
return jax.device_put_sharded(ys, [y.device() for y in ys])
|
||||
ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
|
||||
return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
|
||||
|
||||
# leading axis is batch dim (i.e. mapped/parallel dim), of size 2
|
||||
x = jnp.array([[1., 0., 0.],
|
||||
|
@ -1026,7 +1026,8 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
|
||||
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
|
||||
self.assertEqual(key.device(), key._base_array.device())
|
||||
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
|
||||
self.assertEqual(key.device(), key._base_array.device())
|
||||
self.assertEqual(key.devices(), key._base_array.devices())
|
||||
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
|
||||
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
|
||||
|
Loading…
x
Reference in New Issue
Block a user