Deprecate the device() method of JAX arrays

Jake VanderPlas 2023-11-29 16:52:09 -08:00
parent 4de07b3f62
commit 97beb01c43
14 changed files with 209 additions and 194 deletions

View File

@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
but the previous outputs can be recovered this way: but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)` * `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]` * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
  be replaced with one of the following (see the migration sketch below):
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
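A minimal migration sketch for the deprecation described above; the array and the printed device are illustrative (output depends on the backend), not part of this commit:

```python
import jax.numpy as jnp

x = jnp.arange(4.0)

# Before this change (now emits a DeprecationWarning):
#   dev = x.device()

# After: devices() returns a set of Device objects...
devs = x.devices()          # e.g. {CpuDevice(id=0)} on a CPU backend
dev = next(iter(devs))      # extract the single device when len(devs) == 1

# ...or inspect the sharding for placement information.
print(x.sharding)           # SingleDeviceSharding for an uncommitted array
```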
## jaxlib 0.4.21 ## jaxlib 0.4.21

View File

@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
platforms are available in priority order). platforms are available in priority order).
>>> from jax import numpy as jnp >>> from jax import numpy as jnp
>>> print(jnp.ones(3).device()) # doctest: +SKIP >>> print(jnp.ones(3).devices()) # doctest: +SKIP
cuda:0 {CudaDevice(id=0)}
Computations involving uncommitted data are performed on the default Computations involving uncommitted data are performed on the default
device and the results are uncommitted on the default device. device and the results are uncommitted on the default device.
@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the
>>> import jax >>> import jax
>>> from jax import device_put >>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP >>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP
cuda:2 >>> print(arr.devices()) # doctest: +SKIP
{CudaDevice(id=2)}
Computations involving some committed inputs will happen on the Computations involving some committed inputs will happen on the
committed device and the result will be committed on the committed device and the result will be committed on the
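To make the committed vs. uncommitted distinction above concrete, a small sketch assuming more than one local device; the device indices and printed outputs are illustrative:

```python
import jax
import jax.numpy as jnp

# Uncommitted: created on the default device.
a = jnp.ones(3)
print(a.devices())                  # e.g. {CudaDevice(id=0)}

# Committed: explicitly placed with device_put.
if len(jax.devices()) > 1:
    b = jax.device_put(a, jax.devices()[1])
    print(b.devices())              # e.g. {CudaDevice(id=1)}
    # Computations involving committed inputs stay on that device.
    print((b + 1).devices())        # e.g. {CudaDevice(id=1)}
```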

View File

@ -51,6 +51,11 @@ Device = xc.Device
Index = tuple[slice, ...] Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this. PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
return next(iter(a.devices()))
class Shard: class Shard:
"""A single data shard of an Array. """A single data shard of an Array.
@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
di_map = _cached_index_calc(s, shape) di_map = _cached_index_calc(s, shape)
copy_plan = [] copy_plan = []
for a in arrays: for a in arrays:
ind = di_map.get(a.device(), None) ind = di_map.get(_get_device(a), None)
if ind is not None: if ind is not None:
copy_plan.append((ind, a)) copy_plan.append((ind, a))
return copy_plan return copy_plan
@ -183,7 +188,7 @@ class ArrayImpl(basearray.Array):
"Input buffers to `Array` must have matching dtypes. " "Input buffers to `Array` must have matching dtypes. "
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}") f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")
device_id_to_buffer = {db.device().id: db for db in self._arrays} device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev): if len(self._arrays) != len(addressable_dev):
@ -324,7 +329,7 @@ class ArrayImpl(basearray.Array):
if arr_idx is not None: if arr_idx is not None:
a = self._arrays[arr_idx] a = self._arrays[arr_idx]
return ArrayImpl( return ArrayImpl(
a.aval, SingleDeviceSharding(a.device()), [a], committed=False, a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
_skip_checks=True) _skip_checks=True)
return lax_numpy._rewriting_take(self, idx) return lax_numpy._rewriting_take(self, idx)
else: else:
@ -400,7 +405,7 @@ class ArrayImpl(basearray.Array):
return DLDeviceType.kDLCPU, 0 return DLDeviceType.kDLCPU, 0
elif self.platform() == "gpu": elif self.platform() == "gpu":
platform_version = self.device().client.platform_version platform_version = _get_device(self).client.platform_version
if "cuda" in platform_version: if "cuda" in platform_version:
dl_device_type = DLDeviceType.kDLCUDA dl_device_type = DLDeviceType.kDLCUDA
elif "rocm" in platform_version: elif "rocm" in platform_version:
@ -409,7 +414,7 @@ class ArrayImpl(basearray.Array):
raise ValueError("Unknown GPU platform for __dlpack__: " raise ValueError("Unknown GPU platform for __dlpack__: "
f"{platform_version}") f"{platform_version}")
local_hardware_id = self.device().local_hardware_id local_hardware_id = _get_device(self).local_hardware_id
if local_hardware_id is None: if local_hardware_id is None:
raise ValueError("Couldn't get local_hardware_id for __dlpack__") raise ValueError("Couldn't get local_hardware_id for __dlpack__")
@ -451,6 +456,8 @@ class ArrayImpl(basearray.Array):
# TODO(yashkatariya): Remove this method when everyone is using devices(). # TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device: def device(self) -> Device:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted() self._check_if_deleted()
device_set = self.sharding.device_set device_set = self.sharding.device_set
if len(device_set) == 1: if len(device_set) == 1:
@ -499,7 +506,7 @@ class ArrayImpl(basearray.Array):
self._check_if_deleted() self._check_if_deleted()
out = [] out = []
for a in self._arrays: for a in self._arrays:
out.append(Shard(a.device(), self.sharding, self.shape, a)) out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out return out
@property @property
@ -514,7 +521,7 @@ class ArrayImpl(basearray.Array):
return self.addressable_shards return self.addressable_shards
out = [] out = []
device_id_to_buffer = {a.device().id: a for a in self._arrays} device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
for global_d in self.sharding.device_set: for global_d in self.sharding.device_set:
if device_id_to_buffer.get(global_d.id, None) is not None: if device_id_to_buffer.get(global_d.id, None) is not None:
array = device_id_to_buffer[global_d.id] array = device_id_to_buffer[global_d.id]
@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Try to find a candidate buffer already on the correct device, # Try to find a candidate buffer already on the correct device,
# otherwise copy one of them. # otherwise copy one of them.
for buf in candidates_list: for buf in candidates_list:
if buf.device() == device: if buf.devices() == {device}:
bufs.append(buf) bufs.append(buf)
break break
else: else:
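Because `ArrayImpl.device()` above now issues a `DeprecationWarning`, downstream code can surface lingering call sites during migration; a sketch using the standard `warnings` module, with an illustrative array:

```python
import warnings

import jax.numpy as jnp

x = jnp.ones(3)
with warnings.catch_warnings():
    # Escalate the deprecation to an error so stray device() calls fail loudly.
    warnings.simplefilter("error", DeprecationWarning)
    try:
        x.device()
    except DeprecationWarning:
        device, = x.devices()       # migrate to devices() instead
```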

View File

@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray,
bufs = [x for x, d in safe_zip(xs, devices) bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and dispatch.is_single_device_sharding(x.sharding) and
x.device() == d)] x.devices() == {d})]
if len(bufs) == len(xs): if len(bufs) == len(xs):
return array.ArrayImpl( return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True) aval, sharding, bufs, committed=committed, _skip_checks=True)
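The filter above keeps only buffers that already live on the target device; a user-level sketch of the same check (the array and device here are illustrative, and `isinstance(..., SingleDeviceSharding)` stands in for the internal `dispatch.is_single_device_sharding` helper):

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

x = jnp.arange(4)
d = jax.devices()[0]

# True when x is a single-device array already resident on d.
already_placed = isinstance(x.sharding, SingleDeviceSharding) and x.devices() == {d}
print(already_placed)
```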

View File

@ -228,7 +228,7 @@ class JitTest(jtu.BufferDonationTestCase):
device = jax.devices()[-1] device = jax.devices()[-1]
x = jit(lambda x: x, device=device)(3.) x = jit(lambda x: x, device=device)(3.)
_check_instance(self, x) _check_instance(self, x)
self.assertEqual(x.device(), device) self.assertEqual(x.devices(), {device})
@parameterized.named_parameters( @parameterized.named_parameters(
('jit', jax.jit), ('jit', jax.jit),
@ -239,42 +239,44 @@ class JitTest(jtu.BufferDonationTestCase):
if jax.device_count() == 1: if jax.device_count() == 1:
raise unittest.SkipTest("Test requires multiple devices") raise unittest.SkipTest("Test requires multiple devices")
system_default_device = jnp.add(1, 1).device() system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
system_default_device = list(system_default_devices)[0]
test_device = jax.devices()[-1] test_device = jax.devices()[-1]
self.assertNotEqual(system_default_device, test_device) self.assertNotEqual(system_default_device, test_device)
f = module(lambda x: x + 1) f = module(lambda x: x + 1)
self.assertEqual(f(1).device(), system_default_device) self.assertEqual(f(1).devices(), system_default_devices)
with jax.default_device(test_device): with jax.default_device(test_device):
self.assertEqual(jnp.add(1, 1).device(), test_device) self.assertEqual(jnp.add(1, 1).devices(), {test_device})
self.assertEqual(f(1).device(), test_device) self.assertEqual(f(1).devices(), {test_device})
self.assertEqual(jnp.add(1, 1).device(), system_default_device) self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(1).device(), system_default_device) self.assertEqual(f(1).devices(), system_default_devices)
with jax.default_device(test_device): with jax.default_device(test_device):
# Explicit `device` or `backend` argument to jit overrides default_device # Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual( self.assertEqual(
module(f, device=system_default_device)(1).device(), module(f, device=system_default_device)(1).devices(),
system_default_device) system_default_devices)
out = module(f, backend="cpu")(1) out = module(f, backend="cpu")(1)
self.assertEqual(out.device().platform, "cpu") self.assertEqual(next(iter(out.devices())).platform, "cpu")
# Sticky input device overrides default_device # Sticky input device overrides default_device
sticky = jax.device_put(1, system_default_device) sticky = jax.device_put(1, system_default_device)
self.assertEqual(jnp.add(sticky, 1).device(), system_default_device) self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).device(), system_default_device) self.assertEqual(f(sticky).devices(), system_default_devices)
# Test nested default_devices # Test nested default_devices
with jax.default_device(system_default_device): with jax.default_device(system_default_device):
self.assertEqual(f(1).device(), system_default_device) self.assertEqual(f(1).devices(), system_default_devices)
self.assertEqual(f(1).device(), test_device) self.assertEqual(f(1).devices(), {test_device})
# Test a few more non-default_device calls for good luck # Test a few more non-default_device calls for good luck
self.assertEqual(jnp.add(1, 1).device(), system_default_device) self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).device(), system_default_device) self.assertEqual(f(sticky).devices(), system_default_devices)
self.assertEqual(f(1).device(), system_default_device) self.assertEqual(f(1).devices(), system_default_devices)
# TODO(skye): make this work! # TODO(skye): make this work!
def test_jit_default_platform(self): def test_jit_default_platform(self):
@ -815,8 +817,8 @@ class JitTest(jtu.BufferDonationTestCase):
result = jitted_f(1.) result = jitted_f(1.)
result_cpu = jitted_f_cpu(1.) result_cpu = jitted_f_cpu(1.)
self.assertEqual(result.device().platform, jtu.device_under_test()) self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
self.assertEqual(result_cpu.device().platform, "cpu") self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")
@parameterized.named_parameters( @parameterized.named_parameters(
('jit', jax.jit), ('jit', jax.jit),
@ -1697,7 +1699,7 @@ class APITest(jtu.JaxTestCase):
u = jax.device_put(y, jax.devices()[0]) u = jax.device_put(y, jax.devices()[0])
self.assertArraysAllClose(u, y) self.assertArraysAllClose(u, y)
self.assertEqual(u.device(), jax.devices()[0]) self.assertEqual(u.devices(), {jax.devices()[0]})
def test_device_put_sharding_tree(self): def test_device_put_sharding_tree(self):
if jax.device_count() < 2: if jax.device_count() < 2:
@ -1830,10 +1832,10 @@ class APITest(jtu.JaxTestCase):
d1, d2 = jax.local_devices()[:2] d1, d2 = jax.local_devices()[:2]
data = self.rng().randn(*shape).astype(np.float32) data = self.rng().randn(*shape).astype(np.float32)
x = api.device_put(data, device=d1) x = api.device_put(data, device=d1)
self.assertEqual(x.device(), d1) self.assertEqual(x.devices(), {d1})
y = api.device_put(x, device=d2) y = api.device_put(x, device=d2)
self.assertEqual(y.device(), d2) self.assertEqual(y.devices(), {d2})
np.testing.assert_array_equal(data, np.array(y)) np.testing.assert_array_equal(data, np.array(y))
# Make sure these don't crash # Make sure these don't crash
@ -1848,11 +1850,11 @@ class APITest(jtu.JaxTestCase):
np_arr = np.array([1,2,3]) np_arr = np.array([1,2,3])
scalar = 1 scalar = 1
device_arr = jnp.array([1,2,3]) device_arr = jnp.array([1,2,3])
assert device_arr.device() is default_device assert device_arr.devices() == {default_device}
for val in [np_arr, device_arr, scalar]: for val in [np_arr, device_arr, scalar]:
x = api.device_put(val, device=cpu_device) x = api.device_put(val, device=cpu_device)
self.assertEqual(x.device(), cpu_device) self.assertEqual(x.devices(), {cpu_device})
@jax.default_matmul_precision("float32") @jax.default_matmul_precision("float32")
def test_jacobian(self): def test_jacobian(self):
@ -3852,21 +3854,22 @@ class APITest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") @jtu.skip_on_devices("cpu")
def test_default_device(self): def test_default_device(self):
system_default_device = jnp.zeros(2).device() system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
test_device = jax.devices("cpu")[-1] test_device = jax.devices("cpu")[-1]
# Sanity check creating array using system default device # Sanity check creating array using system default device
self.assertEqual(jnp.ones(1).device(), system_default_device) self.assertEqual(jnp.ones(1).devices(), system_default_devices)
# Create array with default_device set # Create array with default_device set
with jax.default_device(test_device): with jax.default_device(test_device):
# Hits cached primitive path # Hits cached primitive path
self.assertEqual(jnp.ones(1).device(), test_device) self.assertEqual(jnp.ones(1).devices(), {test_device})
# Uncached # Uncached
self.assertEqual(jnp.zeros((1, 2)).device(), test_device) self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})
# Test that we can reset to system default device # Test that we can reset to system default device
self.assertEqual(jnp.ones(1).device(), system_default_device) self.assertEqual(jnp.ones(1).devices(), system_default_devices)
def test_dunder_jax_array(self): def test_dunder_jax_array(self):
# https://github.com/google/jax/pull/4725 # https://github.com/google/jax/pull/4725

View File

@ -77,7 +77,7 @@ class DLPackTest(jtu.JaxTestCase):
x = jax.device_put(np, device) x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x) dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack) y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.device(), device) self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y) self.assertAllClose(np.astype(x.dtype), y)
self.assertRaisesRegex(RuntimeError, self.assertRaisesRegex(RuntimeError,
@ -97,11 +97,11 @@ class DLPackTest(jtu.JaxTestCase):
device = jax.devices("gpu" if gpu else "cpu")[0] device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device) x = jax.device_put(np, device)
y = jax.dlpack.from_dlpack(x) y = jax.dlpack.from_dlpack(x)
self.assertEqual(y.device(), device) self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y) self.assertAllClose(np.astype(x.dtype), y)
# Test we can create multiple arrays # Test we can create multiple arrays
z = jax.dlpack.from_dlpack(x) z = jax.dlpack.from_dlpack(x)
self.assertEqual(z.device(), device) self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z) self.assertAllClose(np.astype(x.dtype), z)
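A sketch of the round trip exercised by the test above, run on whatever single device is the default (outputs are backend-dependent):

```python
import jax
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(3.0)
# Newer protocol path: pass the producing array directly to from_dlpack.
y = jax.dlpack.from_dlpack(x)
assert y.devices() == x.devices()   # the imported array lands on the same device
```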

View File

@ -424,7 +424,7 @@ class JaxArrayTest(jtu.JaxTestCase):
x = jnp.array([[1., 0., 0.], [0., 2., 3.]]) x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x) y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y], self.assertArraysEqual([list(a.devices())[0] for a in y],
y.sharding._device_assignment, y.sharding._device_assignment,
allow_object_dtype=True) allow_object_dtype=True)
@ -550,7 +550,7 @@ class JaxArrayTest(jtu.JaxTestCase):
for i, j in zip(arr, iter(input_data)): for i, j in zip(arr, iter(input_data)):
self.assertArraysEqual(i, j) self.assertArraysEqual(i, j)
self.assertEqual(i.device(), single_dev[0]) self.assertEqual(i.devices(), {single_dev[0]})
def test_array_shards_committed(self): def test_array_shards_committed(self):
if jax.device_count() < 2: if jax.device_count() < 2:

View File

@ -61,12 +61,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
def assert_committed_to_device(self, data, device): def assert_committed_to_device(self, data, device):
"""Asserts that the data is committed to the device.""" """Asserts that the data is committed to the device."""
self.assertTrue(data._committed) self.assertTrue(data._committed)
self.assertEqual(data.device(), device) self.assertEqual(data.devices(), {device})
def assert_uncommitted_to_device(self, data, device): def assert_uncommitted_to_device(self, data, device):
"""Asserts that the data is on the device but not committed to it.""" """Asserts that the data is on the device but not committed to it."""
self.assertFalse(data._committed) self.assertFalse(data._committed)
self.assertEqual(data.device(), device) self.assertEqual(data.devices(), {device})
def test_computation_follows_data(self): def test_computation_follows_data(self):
if jax.device_count() < 5: if jax.device_count() < 5:

View File

@ -48,7 +48,7 @@ class MultiBackendTest(jtu.JaxTestCase):
z = fun(x, y) z = fun(x, y)
self.assertAllClose(z, z_host, rtol=1e-2) self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = backend if backend else jtu.device_under_test() correct_platform = backend if backend else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform) self.assertEqual(list(z.devices())[0].platform, correct_platform)
@jtu.sample_product( @jtu.sample_product(
ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)] ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
@ -72,7 +72,7 @@ class MultiBackendTest(jtu.JaxTestCase):
z = fun(x, y) z = fun(x, y)
self.assertAllClose(z, z_host, rtol=1e-2) self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = outer if outer else jtu.device_under_test() correct_platform = outer if outer else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform) self.assertEqual(list(z.devices())[0].platform, correct_platform)
@jtu.sample_product( @jtu.sample_product(
ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'), ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
@ -116,8 +116,8 @@ class MultiBackendTest(jtu.JaxTestCase):
y = npr.uniform(size=(10,10)) y = npr.uniform(size=(10,10))
z = fun(x, y) z = fun(x, y)
w = jnp.sin(z) w = jnp.sin(z)
self.assertEqual(z.device().platform, backend) self.assertEqual(list(z.devices())[0].platform, backend)
self.assertEqual(w.device().platform, backend) self.assertEqual(list(w.devices())[0].platform, backend)
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
def testJitCpu(self): def testJitCpu(self):
@ -131,18 +131,18 @@ class MultiBackendTest(jtu.JaxTestCase):
b = x + jnp.ones_like(x) b = x + jnp.ones_like(x)
c = x + jnp.eye(2) c = x + jnp.eye(2)
self.assertEqual(a.device(), jax.devices('cpu')[0]) self.assertEqual(a.devices(), {jax.devices('cpu')[0]})
self.assertEqual(b.device(), jax.devices('cpu')[0]) self.assertEqual(b.devices(), {jax.devices('cpu')[0]})
self.assertEqual(c.device(), jax.devices('cpu')[0]) self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
def test_closed_over_values_device_placement(self): def test_closed_over_values_device_placement(self):
# see https://github.com/google/jax/issues/1431 # see https://github.com/google/jax/issues/1431
def f(): return jnp.add(3., 4.) def f(): return jnp.add(3., 4.)
self.assertNotEqual(jax.jit(f)().device(), self.assertNotEqual(jax.jit(f)().devices(),
jax.devices('cpu')[0]) {jax.devices('cpu')[0]})
self.assertEqual(jax.jit(f, backend='cpu')().device(), self.assertEqual(jax.jit(f, backend='cpu')().devices(),
jax.devices('cpu')[0]) {jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_jit_on_nondefault_backend(self): def test_jit_on_nondefault_backend(self):
@ -154,22 +154,22 @@ class MultiBackendTest(jtu.JaxTestCase):
self.assertNotEqual(default_dev.platform, "cpu") self.assertNotEqual(default_dev.platform, "cpu")
data_on_cpu = jax.device_put(1, device=cpus[0]) data_on_cpu = jax.device_put(1, device=cpus[0])
self.assertEqual(data_on_cpu.device(), cpus[0]) self.assertEqual(data_on_cpu.devices(), {cpus[0]})
def my_sin(x): return jnp.sin(x) def my_sin(x): return jnp.sin(x)
# jit without any device spec follows the data # jit without any device spec follows the data
result1 = jax.jit(my_sin)(2) result1 = jax.jit(my_sin)(2)
self.assertEqual(result1.device(), default_dev) self.assertEqual(result1.devices(), {default_dev})
result2 = jax.jit(my_sin)(data_on_cpu) result2 = jax.jit(my_sin)(data_on_cpu)
self.assertEqual(result2.device(), cpus[0]) self.assertEqual(result2.devices(), {cpus[0]})
# jit with `device` spec places the data on the specified device # jit with `device` spec places the data on the specified device
result3 = jax.jit(my_sin, device=cpus[0])(2) result3 = jax.jit(my_sin, device=cpus[0])(2)
self.assertEqual(result3.device(), cpus[0]) self.assertEqual(result3.devices(), {cpus[0]})
# jit with `backend` spec places the data on the specified backend # jit with `backend` spec places the data on the specified backend
result4 = jax.jit(my_sin, backend="cpu")(2) result4 = jax.jit(my_sin, backend="cpu")(2)
self.assertEqual(result4.device(), cpus[0]) self.assertEqual(result4.devices(), {cpus[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_indexing(self): def test_indexing(self):
@ -178,7 +178,7 @@ class MultiBackendTest(jtu.JaxTestCase):
x = jax.device_put(np.ones(2), cpus[0]) x = jax.device_put(np.ones(2), cpus[0])
y = x[0] y = x[0]
self.assertEqual(y.device(), cpus[0]) self.assertEqual(y.devices(), {cpus[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_sum(self): def test_sum(self):
@ -187,7 +187,7 @@ class MultiBackendTest(jtu.JaxTestCase):
x = jax.device_put(np.ones(2), cpus[0]) x = jax.device_put(np.ones(2), cpus[0])
y = x.sum() y = x.sum()
self.assertEqual(y.device(), cpus[0]) self.assertEqual(y.devices(), {cpus[0]})
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -98,7 +98,7 @@
"import jax\n", "import jax\n",
"key = jax.random.PRNGKey(1701)\n", "key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n", "arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n", "device = list(arr.devices())[0]\n",
"print(f\"JAX device type: {device}\")\n", "print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
] ]
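The notebook cell above pulls one device out of the set to check its platform; the same idiom in plain Python (the reported platform string depends on the runtime):

```python
import jax

key = jax.random.PRNGKey(1701)
arr = jax.random.normal(key, (1000,))
device = next(iter(arr.devices()))      # equivalent to list(arr.devices())[0]
print(f"JAX device type: {device}")
print(device.platform)                  # "cpu", "gpu", or "tpu"
```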

View File

@ -1,24 +1,10 @@
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Colab GPU Test",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "view-in-github", "colab_type": "text",
"colab_type": "text" "id": "view-in-github"
}, },
"source": [ "source": [
"<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" "<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@ -27,8 +13,8 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "WkadOyTDCAWD", "colab_type": "text",
"colab_type": "text" "id": "WkadOyTDCAWD"
}, },
"source": [ "source": [
"# JAX Colab GPU Test\n", "# JAX Colab GPU Test\n",
@ -38,15 +24,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1,
"metadata": { "metadata": {
"id": "_tKNrbqqBHwu",
"colab_type": "code",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 68 "height": 68
} },
"colab_type": "code",
"id": "_tKNrbqqBHwu",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu-t4-s-kbefivsjoreh\n",
"0.1.64\n",
"0.1.45\n"
]
}
],
"source": [ "source": [
"import jax\n", "import jax\n",
"import jaxlib\n", "import jaxlib\n",
@ -54,25 +52,13 @@
"!cat /var/colab/hostname\n", "!cat /var/colab/hostname\n",
"print(jax.__version__)\n", "print(jax.__version__)\n",
"print(jaxlib.__version__)" "print(jaxlib.__version__)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"gpu-t4-s-kbefivsjoreh\n",
"0.1.64\n",
"0.1.45\n"
],
"name": "stdout"
}
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "oqEG21rADO1F", "colab_type": "text",
"colab_type": "text" "id": "oqEG21rADO1F"
}, },
"source": [ "source": [
"## Confirm Device" "## Confirm Device"
@ -80,39 +66,39 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2,
"metadata": { "metadata": {
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 34 "height": 34
} },
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX device type: gpu:0\n"
]
}
],
"source": [ "source": [
"import jax\n", "import jax\n",
"key = jax.random.PRNGKey(1701)\n", "key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n", "arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n", "device = list(arr.devices())[0]\n",
"print(f\"JAX device type: {device}\")\n", "print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"gpu\", \"unexpected JAX device type\"" "assert device.platform == \"gpu\", \"unexpected JAX device type\""
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"JAX device type: gpu:0\n"
],
"name": "stdout"
}
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "z0FUY9yUC4k1", "colab_type": "text",
"colab_type": "text" "id": "z0FUY9yUC4k1"
}, },
"source": [ "source": [
"## Matrix Multiplication" "## Matrix Multiplication"
@ -120,15 +106,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3,
"metadata": { "metadata": {
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 34 "height": 34
} },
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0216676\n"
]
}
],
"source": [ "source": [
"import jax\n", "import jax\n",
"import numpy as np\n", "import numpy as np\n",
@ -138,23 +134,13 @@
"x = jax.random.normal(key, (3000, 3000))\n", "x = jax.random.normal(key, (3000, 3000))\n",
"result = jax.numpy.dot(x, x.T).mean()\n", "result = jax.numpy.dot(x, x.T).mean()\n",
"print(result)" "print(result)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"1.0216676\n"
],
"name": "stdout"
}
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "0zTA2Q19DW4G", "colab_type": "text",
"colab_type": "text" "id": "0zTA2Q19DW4G"
}, },
"source": [ "source": [
"## Linear Algebra" "## Linear Algebra"
@ -162,15 +148,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4,
"metadata": { "metadata": {
"id": "uW9j84_UDYof",
"colab_type": "code",
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 51 "height": 51
} },
"colab_type": "code",
"id": "uW9j84_UDYof",
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
" 2.866489 1.8229384 1.5478926]\n"
]
}
],
"source": [ "source": [
"import jax.numpy as jnp\n", "import jax.numpy as jnp\n",
"import jax.random as rand\n", "import jax.random as rand\n",
@ -184,24 +181,13 @@
"assert u.shape == (N, N)\n", "assert u.shape == (N, N)\n",
"assert vt.shape == (M, M)\n", "assert vt.shape == (M, M)\n",
"print(s)" "print(s)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
" 2.866489 1.8229384 1.5478926]\n"
],
"name": "stdout"
}
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "jCyKUn4-DCXn", "colab_type": "text",
"colab_type": "text" "id": "jCyKUn4-DCXn"
}, },
"source": [ "source": [
"## XLA Compilation" "## XLA Compilation"
@ -209,15 +195,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5,
"metadata": { "metadata": {
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 51 "height": 51
} },
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
" 0.13093245]\n"
]
}
],
"source": [ "source": [
"@jax.jit\n", "@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n", "def selu(x, alpha=1.67, lmbda=1.05):\n",
@ -225,18 +222,21 @@
"x = jax.random.normal(key, (5000,))\n", "x = jax.random.normal(key, (5000,))\n",
"result = selu(x).block_until_ready()\n", "result = selu(x).block_until_ready()\n",
"print(result)" "print(result)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
" 0.13093245]\n"
],
"name": "stdout"
}
] ]
} }
] ],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "JAX Colab GPU Test",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
} }

View File

@ -2539,7 +2539,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertFalse(a._committed) self.assertFalse(a._committed)
out = f(a, a) out = f(a, a)
self.assertFalse(out._committed) self.assertFalse(out._committed)
self.assertEqual(out.device(), jax.devices()[0]) self.assertEqual(out.devices(), {jax.devices()[0]})
self.assertArraysEqual(out, a * 2) self.assertArraysEqual(out, a * 2)
with jax.default_device(jax.devices()[1]): with jax.default_device(jax.devices()[1]):
@ -2547,7 +2547,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertFalse(b._committed) self.assertFalse(b._committed)
out2 = f(b, b) out2 = f(b, b)
self.assertFalse(out2._committed) self.assertFalse(out2._committed)
self.assertEqual(out2.device(), jax.devices()[1]) self.assertEqual(out2.devices(), {jax.devices()[1]})
self.assertArraysEqual(out2, b * 2) self.assertArraysEqual(out2, b * 2)
def test_pjit_with_static_argnames(self): def test_pjit_with_static_argnames(self):
@ -2590,7 +2590,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
if jax.device_count() <= 1: if jax.device_count() <= 1:
self.skipTest('Test requires more >1 device.') self.skipTest('Test requires more >1 device.')
system_default_device = jnp.add(1, 1).device() system_default_device = list(jnp.add(1, 1).devices())[0]
test_device = jax.devices()[-1] test_device = jax.devices()[-1]
f = pjit(lambda x: x + 1) f = pjit(lambda x: x + 1)
@ -2733,7 +2733,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x @ x.T return x @ x.T
def _check(out, expected_device, expected_out): def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device) self.assertEqual(out.devices(), {expected_device})
self.assertLen(out.sharding.device_set, 1) self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out @ expected_out.T) self.assertArraysEqual(out, expected_out @ expected_out.T)
@ -2776,14 +2776,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
expected_device = jax.devices()[2] expected_device = jax.devices()[2]
final_out = pjit(lambda x: x * 3, device=expected_device)(out) final_out = pjit(lambda x: x * 3, device=expected_device)(out)
self.assertEqual(final_out.device(), expected_device) self.assertEqual(final_out.devices(), {expected_device})
self.assertLen(final_out.sharding.device_set, 1) self.assertLen(final_out.sharding.device_set, 1)
self.assertArraysEqual(final_out, inp * 6) self.assertArraysEqual(final_out, inp * 6)
@jtu.run_on_devices("tpu") @jtu.run_on_devices("tpu")
def test_pjit_with_backend_arg(self): def test_pjit_with_backend_arg(self):
def _check(out, expected_device, expected_out): def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device) self.assertEqual(out.devices(), {expected_device})
self.assertLen(out.sharding.device_set, 1) self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out) self.assertArraysEqual(out, expected_out)
@ -3403,7 +3403,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
y = jax.device_put(x, jax.devices()[1]) y = jax.device_put(x, jax.devices()[1])
out2 = jax.jit(lambda x: x)(y) out2 = jax.jit(lambda x: x)(y)
self.assertIsInstance(out2.sharding, SingleDeviceSharding) self.assertIsInstance(out2.sharding, SingleDeviceSharding)
self.assertEqual(out2.device(), jax.devices()[1]) self.assertEqual(out2.devices(), {jax.devices()[1]})
out3 = jax.jit(lambda x: x * 2)(x) out3 = jax.jit(lambda x: x * 2)(x)
self.assertIsInstance(out3.sharding, SingleDeviceSharding) self.assertIsInstance(out3.sharding, SingleDeviceSharding)
@ -3411,7 +3411,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out4 = jax.jit(lambda x: x * 3, out4 = jax.jit(lambda x: x * 3,
out_shardings=SingleDeviceSharding(jax.devices()[1]))(x) out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
self.assertIsInstance(out4.sharding, SingleDeviceSharding) self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1]) self.assertEqual(out4.devices(), {jax.devices()[1]})
def test_none_out_sharding(self): def test_none_out_sharding(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@ -3449,7 +3449,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
arr4 = jax.device_put(jnp.arange(8), jax.devices()[1]) arr4 = jax.device_put(jnp.arange(8), jax.devices()[1])
out4 = jnp.copy(arr4) out4 = jnp.copy(arr4)
self.assertIsInstance(out4.sharding, SingleDeviceSharding) self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1]) self.assertEqual(out4.devices(), {jax.devices()[1]})
def test_get_indices_cache(self): def test_get_indices_cache(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -3506,12 +3506,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Fill up the to_gspmd_sharding cache so that the next jit will miss it. # Fill up the to_gspmd_sharding cache so that the next jit will miss it.
out = jax.jit(identity, out = jax.jit(identity,
in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp) in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp)
self.assertEqual(out.device(), jax.devices()[0]) self.assertEqual(out.devices(), {jax.devices()[0]})
self.assertArraysEqual(out, np_inp) self.assertArraysEqual(out, np_inp)
out2 = jax.jit(identity, device=jax.devices()[0])( out2 = jax.jit(identity, device=jax.devices()[0])(
jax.device_put(np_inp, NamedSharding(mesh, P('x')))) jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
self.assertEqual(out2.device(), jax.devices()[0]) self.assertEqual(out2.devices(), {jax.devices()[0]})
self.assertArraysEqual(out2, np_inp) self.assertArraysEqual(out2, np_inp)
def test_jit_submhlo_cached(self): def test_jit_submhlo_cached(self):

View File

@ -147,12 +147,12 @@ class PythonPmapTest(jtu.JaxTestCase):
view = jnp.array(buf, copy=False) view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view) self.assertArraysEqual(sda[-1], view)
self.assertEqual(buf.device(), view.device()) self.assertSetEqual(buf.devices(), view.devices())
self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer()) self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer())
copy = jnp.array(buf, copy=True) copy = jnp.array(buf, copy=True)
self.assertArraysEqual(sda[-1], copy) self.assertArraysEqual(sda[-1], copy)
self.assertEqual(buf.device(), copy.device()) self.assertSetEqual(buf.devices(), copy.devices())
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer()) self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
def _getMeshShape(self, device_mesh_shape): def _getMeshShape(self, device_mesh_shape):
@ -869,7 +869,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# test that we can handle device movement on dispatch # test that we can handle device movement on dispatch
bufs = y._arrays[::-1] bufs = y._arrays[::-1]
sharding = jax.sharding.PmapSharding( sharding = jax.sharding.PmapSharding(
[b.device() for b in bufs], y.sharding.sharding_spec) [list(b.devices())[0] for b in bufs], y.sharding.sharding_spec)
y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs) y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs)
z = f(y) z = f(y)
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False) self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
@ -2769,7 +2769,7 @@ class ArrayTest(jtu.JaxTestCase):
self.assertEqual(s.replica_id, 0) self.assertEqual(s.replica_id, 0)
buffers = getattr(y, '_arrays') buffers = getattr(y, '_arrays')
self.assertEqual(len(buffers), len(devices)) self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, jnp.stack(x)) self.assertArraysEqual(y, jnp.stack(x))
def test_device_put_sharded_pytree(self): def test_device_put_sharded_pytree(self):
@ -2781,12 +2781,12 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y1, array.ArrayImpl) self.assertIsInstance(y1, array.ArrayImpl)
self.assertArraysEqual(y1, jnp.array([a for a, _ in x])) self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
y1_buffers = getattr(y1, '_arrays') y1_buffers = getattr(y1, '_arrays')
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
self.assertIsInstance(y2, array.ArrayImpl) self.assertIsInstance(y2, array.ArrayImpl)
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x])) self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
y2_buffers = getattr(y2, '_arrays') y2_buffers = getattr(y2, '_arrays')
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
def test_device_put_replicated(self): def test_device_put_replicated(self):
devices = jax.local_devices() devices = jax.local_devices()
@ -2796,7 +2796,7 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y, array.ArrayImpl) self.assertIsInstance(y, array.ArrayImpl)
buffers = getattr(y, '_arrays') buffers = getattr(y, '_arrays')
self.assertEqual(len(buffers), len(devices)) self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, np.stack([x for _ in devices])) self.assertArraysEqual(y, np.stack([x for _ in devices]))
def test_device_put_replicated_pytree(self): def test_device_put_replicated_pytree(self):
@ -2809,13 +2809,13 @@ class ArrayTest(jtu.JaxTestCase):
self.assertIsInstance(y1, array.ArrayImpl) self.assertIsInstance(y1, array.ArrayImpl)
y1_buffers = getattr(y1, '_arrays') y1_buffers = getattr(y1, '_arrays')
self.assertEqual(len(y1_buffers), len(devices)) self.assertEqual(len(y1_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(y1_buffers, devices)))
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices])) self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
self.assertIsInstance(y2, array.ArrayImpl) self.assertIsInstance(y2, array.ArrayImpl)
y2_buffers = getattr(y2, '_arrays') y2_buffers = getattr(y2, '_arrays')
self.assertEqual(len(y2_buffers), len(devices)) self.assertEqual(len(y2_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices))) self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices)))
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices])) self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
def test_repr(self): def test_repr(self):
@ -3127,8 +3127,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.skipTest('Test requires >= 2 devices.') self.skipTest('Test requires >= 2 devices.')
def amap(f, xs): def amap(f, xs):
ys = [f(jax.device_put(x, x.device())) for x in xs] ys = [f(jax.device_put(x, list(x.devices())[0])) for x in xs]
return jax.device_put_sharded(ys, [y.device() for y in ys]) return jax.device_put_sharded(ys, [list(y.devices())[0] for y in ys])
# leading axis is batch dim (i.e. mapped/parallel dim), of size 2 # leading axis is batch dim (i.e. mapped/parallel dim), of size 2
x = jnp.array([[1., 0., 0.], x = jnp.array([[1., 0., 0.],

View File

@ -1026,7 +1026,8 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable) self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated) self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
self.assertEqual(key.device(), key._base_array.device()) with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
self.assertEqual(key.device(), key._base_array.device())
self.assertEqual(key.devices(), key._base_array.devices()) self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes) self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer) self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)