rocm_jax/tests/array_test.py
Skye Wanderman-Milne 120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, the test suite run
time drops from 1h 45m to 35m (both timings include the slow tests).

I tried using bazel at first, which already supported parallel execution
across TPU cores, but somehow it still takes 2h 20m, and I'm not sure why
it's so slow. Bazel appears to create many new test processes over time,
whereas pytest reuses the initially specified number of processes, and
starting and stopping the TPU runtime takes a few seconds, so that may be
adding up. Single-process bazel also appears to be slower than
single-process pytest, which I haven't looked into yet.
2022-11-18 22:05:13 +00:00


# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
import contextlib
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.util import prod, safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P
from jax.experimental.serialize_executable import (
    compile_and_serialize, load_compiled)
from jax._src import sharding
from jax._src import array
from jax._src import prng
from jax._src.lib import xla_extension_version
from jax.experimental import maps
from jax.config import config
config.parse_flags_with_absl()
prev_xla_flags = None
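
# Note on the mark below: per the commit message above, the `multiaccelerator`
# pytest mark labels this whole module as needing multiple devices so that
# pytest-based runners (e.g. pytest-xdist in the Cloud TPU CI) can schedule it
# appropriately. It has no effect unless pytest is installed.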
with contextlib.suppress(ImportError):
  import pytest
  pytestmark = pytest.mark.multiaccelerator

# Run all tests with 8 CPU devices.
def setUpModule():
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xb.get_backend.cache_clear()


# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  if prev_xla_flags is None:
    del os.environ["XLA_FLAGS"]
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  xb.get_backend.cache_clear()


def create_array(shape, sharding, global_data=None):
  if global_data is None:
    global_data = np.arange(prod(shape)).reshape(shape)

  return array.make_array_from_callback(
      shape, sharding, lambda idx: global_data[idx]), global_data
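
# For example, with the 8-device (4, 2) mesh used throughout these tests and a
# NamedSharding over P('x', 'y'), an (8, 2) global array comes back as a
# jax.Array whose eight addressable shards are the (2, 1) blocks of
# `global_data`, one per device (exercised in test_array_2d_shard below).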


@jtu.with_config(jax_array=True)
class JaxArrayTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_jax_array_value(self, mesh_axes):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, global_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, mesh_axes))
    for s in arr.addressable_shards:
      self.assertLen(s.data._arrays, 1)
      self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
    self.assertArraysEqual(arr._value, global_data)
    self.assertArraysEqual(arr._npy_value, global_data)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but for convenience we check only the first 2.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    s = sharding.NamedSharding(global_mesh, mesh_axes)
    arr, global_input_data = create_array(global_input_shape, s)
    self.assertEqual(arr.ndim, 2)
    self.assertEqual(arr.size, 16)
    self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
    self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
    replica_ids = [i.replica_id for i in arr.addressable_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in arr.addressable_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
    for i, s in enumerate(arr.addressable_shards):
      self.assertEqual(s.data.aval,
                       jax.ShapedArray(expected_shard_shape, s.data.dtype))
      self.assertArraysEqual(s.data, global_input_data[s.index])
      self.assertArraysEqual(s.data, arr.addressable_data(i))

    for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  def test_addressable_data(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    s = sharding.NamedSharding(global_mesh, P(None))
    arr, inp_data = create_array(shape, s)
    for i in range(len(arr)):
      self.assertArraysEqual(inp_data, arr.addressable_data(i))

  def test_array_delete(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_single_device_array_usage_after_delete(self):
    x = jnp.array([1, 2, 3])
    x.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      _ = x + 1

  def test_multi_device_array_usage_after_delete(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      _ = arr + 1

  def test_device_put(self):
    numpy_array = np.array([1, 2, 3])
    arr = jax.device_put(numpy_array, jax.devices()[0])
    self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
    self.assertArraysEqual(arr, numpy_array)
    self.assertEqual(arr._committed, True)
    for i in arr.addressable_shards:
      self.assertArraysEqual(i.data, numpy_array)
      self.assertEqual(i.device, jax.devices()[0])
      self.assertEqual(i.index, (slice(None),))
      self.assertEqual(i.replica_id, 0)

  def test_device_put_array_delete(self):
    arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_array_device_get(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertArraysEqual(jax.device_get(arr), input_data)

  def test_repr(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertStartsWith(repr(arr), "Array(")

  def test_jnp_array(self):
    arr = jnp.array([1, 2, 3])
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
    self.assertEqual(arr._committed, False)
    self.assertFalse(arr.weak_type)

  def test_jnp_array_jit_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = jax.jit(lambda x, y: x + y)(a, b)
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_jnp_add(self):
    arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_normal_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = a + b
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_array_sharded_astype(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
    self.assertLen(arr_float32.addressable_shards, 8)
    for i in arr_float32.addressable_shards:
      self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))

  def test_jnp_array_astype(self):
    arr = jnp.array([1, 2, 3])
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, arr.astype(np.float32))

  def test_sharded_add(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    b, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x')))
    out = a + b
    expected = input_data + input_data
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_sharded_zeros_like(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))
    out = jnp.zeros_like(a)
    expected = jnp.zeros(input_data.shape, dtype=int)
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_zeros_like(self):
    a = jnp.array([1, 2, 3], dtype=np.int32)
    out = jnp.zeros_like(a)
    expected = np.zeros(a.shape, dtype=np.int32)
    self.assertArraysEqual(out, expected)
    self.assertTrue(dispatch.is_single_device_sharding(out.sharding))

  def test_wrong_num_arrays(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    s = sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    di_map = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[di_map[d]], d)
            for d in jax.local_devices()]
    with self.assertRaisesRegex(
        ValueError,
        r'Expected 8 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 4'):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
    with self.assertRaisesRegex(
        ValueError,
        r'Expected 8 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 16'):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)

  def test_arrays_not_in_device_assignment(self):
    if jax.device_count() < 4:
      self.skipTest('Requires more than 4 devices')
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    # sharding device ids = {0, 1}
    s = sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {2, 3}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        "Sharding contains devices {0, 1} that are not present in per-device "
        "arrays. Per-device arrays contain devices {2, 3} that are not present "
        "in the sharding."):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_more_devices_in_sharding_than_arrays(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    # Sharding device ids = {0, 1}
    s = sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 0}
    bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{1\} that are not present in per-device "
        "arrays."):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_different_devices_in_arrays_than_sharding(self):
    if jax.device_count() < 3:
      self.skipTest('Requires more than 3 devices')
    shape = (8, 2)
    mesh = maps.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
    # sharding device ids = {1, 2}
    s = sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 1}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{2\} that are not present in per-device "
        r"arrays. Per-device arrays contain devices \{0\} that are not present "
        "in the sharding."):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
  )
  def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
    shape = (8, 4)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, pspec)
    inp_data = np.arange(prod(shape)).reshape(shape)

    str_expected_shard_shape = str(expected_shard_shape).replace(
        r"(", r"\(").replace(r")", r"\)")
    with self.assertRaisesRegex(
        ValueError,
        f"Expected shard shape {str_expected_shard_shape} doesn't match the "
        "buffer shape"):
      array.make_array_from_callback(shape, mps, lambda idx: inp_data)

  def test_mismatch_dtype(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    s = sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(prod(shape), dtype=np.int32).reshape(shape)
    indices = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
    with self.assertRaisesRegex(
        ValueError,
        "Input buffers to `Array` must have matching dtypes. "
        "Got int32, expected float32"):
      array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

  def test_array_iter_pmap_sharding(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin)(x)
    self.assertArraysEqual([a.device() for a in y],
                           [a.device() for a in y._arrays])

    sin_x = iter(np.sin(x))
    for i, j in zip(iter(y), sin_x):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysAllClose(i, j)

  def test_array_iter_pmap_sharding_last_dim_sharded(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin, out_axes=1)(x)

    for i, j in zip(iter(y), iter(np.sin(x).T)):
      self.assertArraysAllClose(i, j)

  def test_array_iter_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)

  def test_array_iter_replicated_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P(None)))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)
      self.assertLen(i.sharding.device_set, 8)
      self.assertTrue(
          pxla.are_op_shardings_equal(
              arr.sharding._to_xla_op_sharding(arr.ndim),
              i.sharding._to_xla_op_sharding(i.ndim)))

  def test_array_getitem_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))

    # TODO(yashkatariya): `__getitem__` with a specific index takes the fast
    # path after b/245667823 is fixed.
    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, np.array([[4], [6]]))

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])

  def test_array_getitem_replicated_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P(None)))

    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, np.array([[4], [6]]))
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        pxla.are_op_shardings_equal(
            arr.sharding._to_xla_op_sharding(arr.ndim),
            s.sharding._to_xla_op_sharding(s.ndim)))

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        pxla.are_op_shardings_equal(
            arr.sharding._to_xla_op_sharding(arr.ndim),
            s.sharding._to_xla_op_sharding(s.ndim)))

  def test_array_iter_mesh_pspec_sharding_single_device(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    single_dev = jax.devices()[1:2]
    mesh = maps.Mesh(np.array(single_dev), ('x'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, sharding.NamedSharding(mesh, P('x')))

    for i, j in zip(arr, iter(input_data)):
      self.assertArraysEqual(i, j)
      self.assertEqual(i.device(), single_dev[0])

  def test_array_shards_committed(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([1, 2, 3])
    for s in x.addressable_shards:
      self.assertEqual(s.data._committed, x._committed)
      self.assertFalse(s.data._committed)

    y = jax.device_put(x, jax.devices()[1])
    for s in y.addressable_shards:
      self.assertEqual(s.data._committed, y._committed)
      self.assertTrue(s.data._committed)

  def test_array_jnp_array_copy_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))

    c_arr = jnp.array(arr, copy=True)
    self.assertArraysEqual(arr, c_arr)
    self.assertEqual(arr._committed, c_arr._committed)
    for a, c in safe_zip(arr.addressable_shards, c_arr.addressable_shards):
      self.assertArraysEqual(a.data, c.data)
      self.assertEqual(a.index, c.index)
      self.assertEqual(a.replica_id, c.replica_id)
      self.assertEqual(a.device, c.device)
      self.assertNotEqual(a.data.unsafe_buffer_pointer(),
                          c.data.unsafe_buffer_pointer())

  def test_array_device_buffer(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, sharding.NamedSharding(global_mesh, P('x', 'y')))

    for a in arr.device_buffers:
      self.assertIsInstance(a, array.ArrayImpl)

    x = jnp.array([1, 2, 3])
    self.assertIsInstance(x.device_buffer, array.ArrayImpl)

  def test_shape_dtype_struct_sharding_jit(self):
    mesh = jtu.create_global_mesh((8,), ('x'))
    s = sharding.NamedSharding(mesh, P('x'))

    x_dummy = jax.ShapeDtypeStruct(
        shape=(16,),
        dtype=jnp.dtype('float32'),
        sharding=s)

    def f(x):
      return x * 2

    c = jax.jit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertLen(input_shardings, 2)
    self.assertEqual(input_shardings[1], {})
    self.assertEqual(input_shardings[1], {})

    self.assertTrue(
        pxla.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
    self.assertTrue(
        pxla.are_op_shardings_equal(
            output_shardings._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))

  def test_shape_dtype_struct_sharding_pjit(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    s = sharding.NamedSharding(mesh, P('x', 'y'))

    def f(x):
      return x * 2.

    x_dummy = jax.ShapeDtypeStruct(
        shape=(8, 2),
        dtype=jnp.dtype('float32'),
        sharding=s)

    c = pjit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertTrue(
        pxla.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))
    self.assertTrue(
        pxla.are_op_shardings_equal(
            output_shardings._to_xla_op_sharding(x_dummy.ndim),
            s._to_xla_op_sharding(x_dummy.ndim)))

  # TODO(skyewm): remove this test when we can remove the workaround manual
  # defragment API
  @jtu.skip_on_devices('cpu')  # defragment not implemented for TFRT CPU
  def test_defragment(self):
    if xc._version < 105:
      self.skipTest('Test needs jaxlib 0.3.25 or greater to pass')

    # Create a few arrays
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    mpsharding = sharding.NamedSharding(global_mesh, P('x', 'y'))
    arr1, data = create_array(shape, mpsharding)
    arr2, _ = create_array(shape, mpsharding, data)
    arr3, _ = create_array(shape, mpsharding, data)

    # Delete one of them
    arr2.delete()

    # Defragment
    xb.get_backend().defragment()

    # Sanity check remaining arrays
    self.assertArraysEqual(arr1, data)
    self.assertArraysEqual(arr1 + arr3, data * 2)

    # TODO(skyewm): check that defragmentation actually happened. I originally
    # thought to do this with unsafe_buffer_pointer(), but that's not always
    # the device memory address. Other ideas include causing enough
    # fragmentation to OOM, and exposing allocator stats in Python.


@jtu.with_config(jax_array=True)
class ShardingTest(jtu.JaxTestCase):

  def test_mesh_pspec_sharding_interface(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    pspec = P('y', 'x')
    global_shape = (8, 4)
    mp_sharding = sharding.NamedSharding(mesh, pspec)
    di_map = mp_sharding.devices_indices_map(global_shape)
    op_sharding = mp_sharding._to_xla_op_sharding(len(global_shape))
    device_assignment = mp_sharding._device_assignment

    self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
    self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
    self.assertEqual(op_sharding.type, xc.OpSharding.Type.OTHER)
    # P('y', 'x') partitions array dim 0 over the size-2 'y' axis and dim 1
    # over the size-4 'x' axis, hence tile dims [2, 4] and a transposed
    # device order relative to the flat mesh.
    self.assertListEqual(op_sharding.tile_assignment_dimensions, [2, 4])
    self.assertListEqual(op_sharding.tile_assignment_devices,
                         [0, 2, 4, 6, 1, 3, 5, 7])

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_none_x", P(None, "x")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_op_sharding_indices(self, pspec):
    shape = (8, 4)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, pspec)
    ops = sharding.OpShardingSharding(
        list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))
    self.assertDictEqual(
        ops.devices_indices_map(shape), mps.devices_indices_map(shape))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
      ("mesh_fully_replicated", P(), (8, 4)),
  )
  def test_shard_shape(self, pspec, expected_shard_shape):
    shape = (8, 4)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, pspec)
    self.assertEqual(mps.shard_shape(shape), expected_shard_shape)

  def test_uneven_shard_error(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
        r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
        r"factors: \[4, 2\] should evenly divide the shape\)"):
      mps.shard_shape((8, 3))

  def test_pmap_sharding_hash_eq(self):
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')

    shape = (2, 2)
    num_elements = prod(shape)
    inp_data = np.arange(num_elements).reshape(shape)
    out = jax.pmap(lambda x: x)(inp_data)
    self.assertIsInstance(out.sharding, sharding.PmapSharding)
    # Populate the device_indices_map cache.
    _ = out.sharding.devices_indices_map(shape)
    cache_info1 = sharding.PmapSharding.devices_indices_map.cache_info()

    inp_data2 = np.arange(num_elements, num_elements + num_elements).reshape(shape)
    out2 = jax.pmap(lambda x: x)(inp_data2)
    # Populate the device_indices_map cache.
    _ = out2.sharding.devices_indices_map(shape)
    cache_info2 = sharding.PmapSharding.devices_indices_map.cache_info()

    self.assertGreater(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

  def test_is_compatible_error(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
    mps = sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))
    new_mps = sharding.NamedSharding._from_parsed_pspec(
        mps.mesh, mps._parsed_pspec)

    with self.assertRaisesRegex(
        ValueError,
        r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
        "valid for values of rank at least 4, but was applied to a value of rank 2"):
      new_mps.is_compatible_aval(shape)

  def test_is_subclass(self):
    # array version of api_test.py::APITest::test_is_subclass
    self.assertTrue(issubclass(array.ArrayImpl, jnp.ndarray))
    self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))

  def test_op_sharding_sharding_repr(self):
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [4, 1, 2]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    op.replicate_on_last_tile_dim = True
    s = sharding.OpShardingSharding(jax.devices(), op)
    self.assertEqual(
        repr(s),
        'OpShardingSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
        'last_tile_dim_replicate})')

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.REPLICATED
    s2 = sharding.OpShardingSharding(jax.devices(), op2)
    self.assertEqual(repr(s2), 'OpShardingSharding({replicated})')

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (4, 2), (), False),
      ("mesh_x", P("x"), (4, 2), (1,), False),
      ("mesh_y", P("y"), (4, 2), (0,), True),
      ("mesh_none_y", P(None, "y"), (4, 2), (0,), False),
      ("mesh_none_x", P(None, "x"), (4, 2), (1,), True),
      ("mesh_xy", P(("x", "y")), (8, 1), (), False),
      ("mesh_fully_replicated", P(), (4, 2), None, False),
  )
  def test_devices_sharding_op_sharding_lowering(
      self, pspec, shape, axes, transpose):
    value_shape = (8, 4)

    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, pspec)

    devices_sharding = sharding.PositionalSharding(jax.devices())
    devices_sharding = devices_sharding.reshape(shape).replicate(axes)
    if transpose:
      devices_sharding = devices_sharding.T

    op1 = mps._to_xla_op_sharding(len(value_shape))
    op2 = devices_sharding._to_xla_op_sharding(len(value_shape))

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))

  def test_devices_sharding_respects_init_mesh_shape(self):
    value_shape = (8, 4)

    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = sharding.NamedSharding(mesh, P('x', 'y'))

    devices_sharding = sharding.PositionalSharding(mesh.devices)

    op1 = mps._to_xla_op_sharding(len(value_shape))
    op2 = devices_sharding._to_xla_op_sharding(len(value_shape))

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(pxla.are_op_shardings_equal(op1, op2))

  def test_pmap_sharding_repr(self):
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')
    out = jax.pmap(lambda x: x)(jnp.arange(2.))
    str(out.sharding)  # doesn't crash
    repr(out.sharding)  # doesn't crash

  @parameterized.named_parameters(
      ('sharded_dim_0', (4, 2), 0),
      ('sharded_dim_1_0', (4, 2), 1),
      ('sharded_dim_2', (4, 2, 4), 2),
      ('sharded_dim_1_1', (2, 4), 1)
  )
  def test_default_pmap_sharding(self, shape, sharded_dim):
    if jax.device_count() < 4:
      self.skipTest('Test needs >= 4 devices.')
    ps = sharding.PmapSharding.default(shape, sharded_dim)

    inp = jnp.arange(np.prod(shape)).reshape(shape)
    compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
    pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings

    self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment)
    self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec)


@jtu.with_config(jax_array=True)
class RngShardingTest(jtu.JaxTestCase):
  # tests that the PRNGs are automatically sharded as expected

  @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_1d(self, num_devices):
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    mesh = jtu.create_global_mesh((num_devices,), ('x',))
    s = sharding.NamedSharding(mesh, P('x'))

    n = num_devices ** 2
    global_x = jnp.arange(n).astype('uint32')
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    # check computation is fully partitioned and without any communication
    jax.config.update('jax_threefry_partitionable', True)
    unopt_txt = f.lower(x).as_text(dialect='hlo')
    opt_txt = f.lower(x).compile().as_text()
    self.assertIn(   f'[{n}]', unopt_txt)
    self.assertNotIn(f'[{n}]', opt_txt)
    self.assertNotIn('all-reduce', opt_txt)
    self.assertNotIn('collective-permute', opt_txt)

    # check against single-device reference
    y = f(x)
    y_ref1 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref1)

    # check against single-device previous implementation reference
    jax.config.update('jax_threefry_partitionable', False)
    y_ref2 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref2)

  @parameterized.named_parameters(
      {"testcase_name": f"_{mesh_shape}_{pspec}",
       "mesh_shape": mesh_shape, "pspec": pspec}
      for mesh_shape in [(3, 2), (4, 2), (2, 3)]
      for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    global_shape = tuple(np.square(mesh_shape))

    mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
    s = sharding.NamedSharding(mesh, pspec)

    n = prod(global_shape)
    global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    # check computation is fully partitioned and without any communication
    jax.config.update('jax_threefry_partitionable', True)
    unopt_txt = f.lower(x).as_text(dialect='hlo')
    opt_txt = f.lower(x).compile().as_text()
    global_shape_fmt = ','.join(str(x) for x in global_shape)
    self.assertIn(   f'[{global_shape_fmt}]', unopt_txt)
    self.assertNotIn(f'[{global_shape_fmt}]', opt_txt)
    self.assertNotIn('all-reduce', opt_txt)
    self.assertNotIn('collective-permute', opt_txt)

    # check against single-device reference
    y = f(x)
    y_ref1 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref1)

    # check against single-device previous implementation reference
    jax.config.update('jax_threefry_partitionable', False)
    y_ref2 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref2)

  @unittest.skipIf(xla_extension_version < 106,
                   'Pjit pickling requires newer jaxlib.')
  def test_pickle_pjit_lower(self):
    example_exe = jax.jit(lambda x: x * x).lower(
        jax.core.ShapedArray(
            (2, 2), dtype=np.float32)).compile()._executable.xla_executable

    # Skip if CompileOptions is not available. This is true on
    # CPU/GPU/Cloud TPU for now.
    try:
      example_exe.compile_options()
    except Exception as e:
      if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
        raise unittest.SkipTest('Serialization not supported')
      raise e

    def fun(x):
      return x * x

    with maps.Mesh(np.array(jax.devices()), ('data',)):
      lowered = pjit(
          fun,
          in_axis_resources=P('data'),
          out_axis_resources=P(None, 'data'),
      ).lower(jax.ShapedArray(shape=(8, 8), dtype=np.float32))

    def verify_serialization(lowered):
      serialized, in_tree, out_tree = compile_and_serialize(lowered)
      compiled = load_compiled(serialized, in_tree, out_tree)
      self.assertEqual(compiled.as_text(), lowered.compile().as_text())

    verify_serialization(lowered)
    verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())