rocm_jax/tests/array_test.py

1575 lines
61 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import jaxlib_extension_version
from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip
from jax._src.mesh import AxisType, AbstractMesh
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (
_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding,
SdyArraySharding)
from jax.experimental.pjit import pjit
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import prng
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator
def create_array(shape, sharding, global_data=None):
if global_data is None:
global_data = np.arange(math.prod(shape)).reshape(shape)
return array.make_array_from_callback(
shape, sharding, lambda idx: global_data[idx]), global_data
class JaxArrayTest(jtu.JaxTestCase):
def test_array_impl_name(self):
self.assertEqual(array.ArrayImpl.__name__, "ArrayImpl")
@parameterized.named_parameters(
("mesh_x_y", P("x", "y")),
("mesh_x", P("x")),
("mesh_y", P("y")),
("mesh_none_y", P(None, "y")),
("mesh_xy", P(("x", "y"))),
("mesh_fully_replicated", P()),
)
def test_jax_array_value(self, mesh_axes):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, global_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes))
for s in arr.addressable_shards:
self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
self.assertArraysEqual(s.data, global_data[s.index])
self.assertArraysEqual(arr._value, global_data)
if arr._npy_value is not None:
self.assertArraysEqual(arr._npy_value, global_data)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"),
# There are more slices but for convenient purposes, checking for only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
(2, 1),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_x", P("x"),
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
(2, 2),
[0, 1, 0, 1, 0, 1, 0, 1], False),
("mesh_y", P("y"),
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
(4, 2),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_none_y", P(None, "y"),
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
(8, 1),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_xy", P(("x", "y")),
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
(1, 2),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_fully_replicated", P(),
((slice(None), slice(None)), (slice(None), slice(None))),
(8, 2),
[0, 1, 2, 3, 4, 5, 6, 7], True),
)
def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids, expected_is_fully_replicated):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True)
global_input_shape = (8, 2)
s = jax.sharding.NamedSharding(global_mesh, mesh_axes)
arr, global_input_data = create_array(global_input_shape, s)
self.assertEqual(arr.ndim, 2)
self.assertEqual(arr.size, 16)
self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
replica_ids = [i.replica_id for i in arr.addressable_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
self.assertListEqual([i.device.id for i in arr.addressable_shards],
[0, 1, 2, 3, 4, 5, 6, 7])
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
for i, s in enumerate(arr.addressable_shards):
self.assertEqual(s.data.aval,
core.ShapedArray(expected_shard_shape, s.data.dtype))
self.assertArraysEqual(s.data, global_input_data[s.index])
self.assertArraysEqual(s.data, arr.addressable_data(i))
for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
self.assertEqual(g.device, l.device)
self.assertEqual(g.index, l.index)
self.assertEqual(g.replica_id, l.replica_id)
self.assertEqual(g.data.aval, l.data.aval)
self.assertArraysEqual(g.data, l.data)
def test_addressable_data(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
shape = (8, 2)
s = jax.sharding.NamedSharding(global_mesh, P(None))
arr, inp_data = create_array(shape, s)
for i in range(len(arr)):
self.assertArraysEqual(inp_data, arr.addressable_data(i))
def test_array_delete(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)
def test_single_device_array_usage_after_delete(self):
x = jnp.array([1, 2, 3])
x.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
_ = x + 1
def test_multi_device_array_usage_after_delete(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
shape = (8, 2)
arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32),
jax.sharding.NamedSharding(global_mesh, P('x')))
arr.delete()
with self.assertRaisesRegex(
RuntimeError, r'Array has been deleted with shape=int32\[16\].'):
_ = arr + 1
def test_device_put(self):
numpy_array = np.array([1, 2, 3])
arr = jax.device_put(numpy_array, jax.devices()[0])
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
self.assertArraysEqual(arr, numpy_array)
self.assertEqual(arr._committed, True)
for i in arr.addressable_shards:
self.assertArraysEqual(i.data, numpy_array)
self.assertEqual(i.device, jax.devices()[0])
self.assertEqual(i.index, (slice(None),))
self.assertEqual(i.replica_id, 0)
def test_device_put_array_delete(self):
arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)
def test_array_device_get(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
self.assertArraysEqual(jax.device_get(arr), input_data)
def test_repr(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
self.assertStartsWith(repr(arr), "Array(")
def test_empty_repr(self):
shape = (0, 5)
dtype = 'float32'
x = jnp.empty(shape, dtype)
self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})")
def test_jnp_array(self):
arr = jnp.array([1, 2, 3])
self.assertIsInstance(arr, array.ArrayImpl)
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
self.assertEqual(arr._committed, False)
self.assertFalse(arr.weak_type)
def test_jnp_array_jit_add(self):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = jax.jit(lambda x, y: x + y)(a, b)
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
def test_jnp_array_jnp_add(self):
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
def test_jnp_array_normal_add(self):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = a + b
self.assertIsInstance(arr, array.ArrayImpl)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
def test_array_sharded_astype(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
self.assertLen(arr_float32.addressable_shards, 8)
for i in arr_float32.addressable_shards:
self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))
def test_jnp_array_astype(self):
arr = jnp.array([1, 2, 3])
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, arr.astype(np.float32))
def test_array_delete_idempotent(self):
mesh = jtu.create_mesh((2,), ('x',))
arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x')))
arr.delete()
self.assertTrue(arr.is_deleted())
arr.delete() # Run delete again to check if it's idempotent.
self.assertTrue(arr.is_deleted())
def test_sharded_add(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
a, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
b, _ = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x')))
out = a + b
expected = input_data + input_data
self.assertArraysEqual(out, expected)
self.assertLen(out.addressable_shards, 8)
for i in out.addressable_shards:
self.assertArraysEqual(i.data, expected[i.index])
def test_sharded_zeros_like(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
a, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
out = jnp.zeros_like(a)
expected = jnp.zeros(input_data.shape, dtype=a.dtype)
self.assertArraysEqual(out, expected)
self.assertLen(out.addressable_shards, 8)
for i in out.addressable_shards:
self.assertArraysEqual(i.data, expected[i.index])
def test_zeros_like(self):
a = jnp.array([1, 2, 3], dtype=np.int32)
out = jnp.zeros_like(a)
expected = np.zeros(a.shape, dtype=np.int32)
self.assertArraysEqual(out, expected)
self.assertTrue(dispatch.is_single_device_sharding(out.sharding))
def test_wrong_num_arrays(self):
if jax.device_count() < 4:
self.skipTest('Requires more than 4 devices')
shape = (8, 2)
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
devices = jax.local_devices()[:2] # Taking up to 2 devices
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
di_map = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices]
with self.assertRaisesRegex(
ValueError,
r'Expected 2 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 1'):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:1], committed=True)
for buf, d in zip(list(bufs), jax.local_devices()[2:4]):
bufs.append(jax.device_put(buf, d))
with self.assertRaisesRegex(
ValueError,
r'Expected 2 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 4'):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
self.skipTest('Requires more than 4 devices')
shape = (8, 2)
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
# sharding device ids = {0, 1}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {2, 3}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
with self.assertRaisesRegex(
ValueError,
"Addressable devices and per-device arrays devices do not match. "
"Sharding contains devices {0, 1} that are not present in per-device "
"arrays. Per-device arrays contain devices {2, 3} that are not present "
"in the sharding."):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_different_devices_in_arrays_than_sharding(self):
if jax.device_count() < 3:
self.skipTest('Requires more than 3 devices')
shape = (8, 2)
mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
# sharding device ids = {1, 2}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 1}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
with self.assertRaisesRegex(
ValueError,
"Addressable devices and per-device arrays devices do not match. "
r"Sharding contains devices \{2\} that are not present in per-device "
r"arrays. Per-device arrays contain devices \{0\} that are not present "
"in the sharding."):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_duplicated_devices_in_arrays(self):
shape = (8, 2)
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
# Sharding device ids = {0, 1}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(
ValueError,
'When making an array from single-device arrays, the input arrays must'
' be from distinct devices'):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
("mesh_x", P("x"), (2, 4)),
("mesh_y", P("y"), (4, 4)),
("mesh_none_y", P(None, "y"), (8, 2)),
("mesh_none_x", P(None, "x"), (8, 1)),
("mesh_xy", P(("x", "y")), (1, 4)),
("mesh_replicated", P(()), (8, 4)),
)
def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
shape = (8, 4)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, pspec)
inp_data = np.arange(5)
str_expected_shard_shape = str(expected_shard_shape).replace(
r"(", r"\(").replace(r")", r"\)")
with self.assertRaisesRegex(
ValueError,
f"Expected shard shape {str_expected_shard_shape} doesn't match the "
"single device array shape"):
array.make_array_from_callback(shape, mps, lambda idx: inp_data)
def test_mismatch_dtype(self):
shape = (8, 2)
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
indices = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
with self.assertRaisesRegex(
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([list(a.devices())[0] for a in y],
y.sharding._device_assignment,
allow_object_dtype=True)
sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):
self.assertIsInstance(i, array.ArrayImpl)
self.assertArraysAllClose(i, j)
def test_array_iter_pmap_sharding_last_dim_sharded(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin, out_axes=1)(x)
for i, j in zip(iter(y), iter(np.sin(x).T)):
self.assertArraysAllClose(i, j)
def test_array_iter_mesh_pspec_sharding_multi_device(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
for i, j in zip(iter(arr), iter(input_data)):
self.assertIsInstance(i, array.ArrayImpl)
self.assertArraysEqual(i, j)
def test_array_iter_replicated_multi_device(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
for i, j in zip(iter(arr), iter(input_data)):
self.assertIsInstance(i, array.ArrayImpl)
self.assertArraysEqual(i, j)
self.assertLen(i.sharding.device_set, 8)
self.assertTrue(
op_shardings.are_op_shardings_equal(
arr.sharding._to_xla_hlo_sharding(arr.ndim),
i.sharding._to_xla_hlo_sharding(i.ndim)))
def test_array_getitem_mesh_pspec_sharding_multi_device(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
s = arr[2:4, 0:1]
self.assertIsInstance(s, array.ArrayImpl)
self.assertArraysEqual(s, input_data[2:4, 0:1])
p = arr[:2]
self.assertIsInstance(p, array.ArrayImpl)
self.assertArraysEqual(p, input_data[:2])
def test_array_getitem_compile_multi_device_sharding(self):
def _check(out, inp, shard_shape):
self.assertArraysEqual(out, inp)
self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape)
self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding)
global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
input_shape = (4, 4, 2)
arr, np_inp = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z')))
_check(arr[:, -1, :], np_inp[:, -1, :], (2, 1))
_check(arr[0, 0, 0], np_inp[0, 0, 0], ())
_check(arr[-1, -1, :], np_inp[-1, -1, :], (1,))
_check(arr[:, 1, 0], np_inp[:, 1, 0], (2,))
_check(arr[:, :, :], np_inp[:, :, :], (2, 2, 1))
_check(arr[3, :, :], np_inp[3, :, :], (2, 1))
_check(arr[-1, -1, -1], np_inp[-1, -1, -1], ())
_check(arr[2, -1, :], np_inp[2, -1, :], (1,))
_check(arr[2, 3, 1], np_inp[2, 3, 1], ())
_check(arr[-1], np_inp[-1], (2, 1))
_check(arr[:], np_inp[:], (2, 2, 1))
_check(arr[np.array(0), :, :], np_inp[np.array(0), :, :], (2, 1))
_check(arr[jnp.array(0), :, :], np_inp[jnp.array(0), :, :], (2, 1))
_check(arr[0, :2, 1], np_inp[0, :2, 1], (2,))
_check(arr[:, 1::2], np_inp[:, 1::2], (2, 2, 1))
_check(arr[:, -1:, :], np_inp[:, -1:, :], (2, 1, 1))
_check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
_check(arr[:4], np_inp[:4], (2, 2, 1))
_check(arr[::-1], np_inp[::-1], (2, 2, 1))
_check(arr[1], np_inp[1], (2, 1))
def test_array_getitem_replicated_multi_device(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
s = arr[2:4, 0:1]
self.assertIsInstance(s, array.ArrayImpl)
self.assertArraysEqual(s, np.array([[4], [6]]))
self.assertLen(s.sharding.device_set, 8)
self.assertTrue(
op_shardings.are_op_shardings_equal(
arr.sharding._to_xla_hlo_sharding(arr.ndim),
s.sharding._to_xla_hlo_sharding(s.ndim)))
p = arr[:2]
self.assertIsInstance(p, array.ArrayImpl)
self.assertArraysEqual(p, input_data[:2])
self.assertLen(s.sharding.device_set, 8)
self.assertTrue(
op_shardings.are_op_shardings_equal(
arr.sharding._to_xla_hlo_sharding(arr.ndim),
s.sharding._to_xla_hlo_sharding(s.ndim)))
def test_array_iter_mesh_pspec_sharding_single_device(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
single_dev = jax.devices()[1:2]
mesh = jax.sharding.Mesh(np.array(single_dev), ('x'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(mesh, P('x')))
for i, j in zip(arr, iter(input_data)):
self.assertArraysEqual(i, j)
self.assertEqual(i.devices(), {single_dev[0]})
def test_array_shards_committed(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
x = jnp.array([1, 2, 3])
for s in x.addressable_shards:
self.assertEqual(s.data._committed, x._committed)
self.assertFalse(s.data._committed)
y = jax.device_put(x, jax.devices()[1])
for s in y.addressable_shards:
self.assertEqual(s.data._committed, y._committed)
self.assertTrue(s.data._committed)
def test_array_jnp_array_copy_multi_device(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
c_arr = jnp.array(arr, copy=True)
self.assertArraysEqual(arr, c_arr)
self.assertEqual(arr._committed, c_arr._committed)
for a, c in safe_zip(arr.addressable_shards, c_arr.addressable_shards):
self.assertArraysEqual(a.data, c.data)
self.assertEqual(a.index, c.index)
self.assertEqual(a.replica_id, c.replica_id)
self.assertEqual(a.device, c.device)
self.assertNotEqual(a.data.unsafe_buffer_pointer(),
c.data.unsafe_buffer_pointer())
def test_array_addressable_shards(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
for a in arr.addressable_shards:
self.assertIsInstance(a.data, array.ArrayImpl)
x = jnp.array([1, 2, 3])
self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)
def test_array_not_hashable(self):
x = jnp.arange(4)
with self.assertRaisesRegex(TypeError, "unhashable type"):
hash(x)
with self.assertRaisesRegex(TypeError, "unhashable type"):
jax.jit(hash)(x)
with self.assertRaisesRegex(TypeError, "unhashable type"):
jax.vmap(hash)(x)
def test_shape_dtype_struct_sharding_jit(self):
mesh = jtu.create_mesh((8,), ('x'))
s = jax.sharding.NamedSharding(mesh, P('x'))
x_dummy = jax.ShapeDtypeStruct(
shape=(16,),
dtype=jnp.dtype('float32'),
sharding=s)
def f(x):
return x * 2
c = jax.jit(f).lower(x_dummy).compile()
input_shardings, output_shardings = c.input_shardings, c.output_shardings
self.assertLen(input_shardings, 2)
self.assertEqual(input_shardings[1], {})
self.assertEqual(input_shardings[1], {})
self.assertTrue(
op_shardings.are_op_shardings_equal(
input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
s._to_xla_hlo_sharding(x_dummy.ndim)))
self.assertTrue(
op_shardings.are_op_shardings_equal(
output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
s._to_xla_hlo_sharding(x_dummy.ndim)))
def test_shape_dtype_struct_sharding_pjit(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
def f(x):
return x * 2.
x_dummy = jax.ShapeDtypeStruct(
shape=(8, 2),
dtype=jnp.dtype('float32'),
sharding=s)
c = pjit(f).lower(x_dummy).compile()
input_shardings, output_shardings = c.input_shardings, c.output_shardings
self.assertTrue(
op_shardings.are_op_shardings_equal(
input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
s._to_xla_hlo_sharding(x_dummy.ndim)))
self.assertTrue(
op_shardings.are_op_shardings_equal(
output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
s._to_xla_hlo_sharding(x_dummy.ndim)))
# TODO(b/399879011): GPU is the only platform that has an implementation for
# this, which exists in py_client.cc. Ideally, this would be replaced with
# some kind of auto-defrag-on-OOM.
@jtu.run_on_devices('gpu')
def test_defragment(self):
# Since the GPU implementation is in py_client.cc, it cannot be exposed via
# the PjRt C API.
if xb.using_pjrt_c_api():
self.skipTest('Manual defragment not exposed via PJRT C API')
# Create a few arrays
global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',))
shape = (8, 2)
mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',))
arr1, data = create_array(shape, mpsharding)
arr2, _ = create_array(shape, mpsharding, data)
arr3, _ = create_array(shape, mpsharding, data)
# Delete one of them
arr2.delete()
# Defragment.
xb.get_backend().defragment()
# Sanity check remaining arrays
self.assertArraysEqual(arr1, data)
self.assertArraysEqual(arr1 + arr3, data * 2)
# TODO(skyewm): check that defragmentation actually happened. I originally
# thought to do this with unsafe_buffer_pointer(), but that's not always the
# device memory address. Other ideas include causing enough fragmentation to
# OOM, and exposing allocator stats in Python.
def test_on_device_size_in_bytes(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
a, _ = create_array(
(8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
shard_size = a.addressable_shards[0].data.on_device_size_in_bytes()
self.assertGreaterEqual(shard_size, 4 * 2)
self.assertEqual(shard_size * len(a.global_shards),
a.on_device_size_in_bytes())
def test_array_is_ready(self):
x = jax.device_put(jnp.arange(8.), jax.devices()[0])
x.is_ready() # doesn't crash
def test_process_allgather_single_host(self):
x = jnp.arange(8.)
out = multihost_utils.process_allgather(x, tiled=True)
self.assertEqual(out.shape, x.shape)
self.assertArraysEqual(out, x)
out = multihost_utils.process_allgather(x)
self.assertEqual(out.shape, (1, x.shape[0]))
self.assertArraysEqual(out, np.expand_dims(x, axis=0))
@jtu.sample_product(
dtype=jtu.dtypes.all,
shape=[(), (10), (2, 3)],
)
@jtu.run_on_devices("cpu")
def test_buffer_protocol(self, dtype, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jax.device_put(x)
if dtype == jax.dtypes.bfloat16:
with self.assertRaisesRegex(
BufferError,
'Buffers of type BF16 are not supported by the Python buffer '
'protocol.'
):
memoryview(y)
return
x_bytes = memoryview(x).tobytes()
y_bytes = memoryview(y).tobytes()
self.assertEqual(x_bytes, y_bytes)
@jtu.run_on_devices("cpu")
def test_buffer_protocol_deletion(self):
rng = jtu.rand_default(self.rng())
x = rng((3, 4), np.float32)
y = jax.device_put(x)
x_bytes = memoryview(x).tobytes()
y_view = memoryview(y)
# The array does not actually get deleted until any external reference is
# dropped. Arguably we should make calling delete() in these circumstances
# return an error instead, but that would be a behavior change for existing
# users.
y.delete()
y_bytes = y_view.tobytes()
self.assertEqual(x_bytes, y_bytes)
def test_array_copy_to_host_async(self):
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = pjit(lambda: jnp.arange(8.),
out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))()
self.assertLen(x.sharding.device_set, 4)
x.copy_to_host_async() # doesn't crash
self.assertArraysEqual(np.arange(8.), x)
def test_array_fully_replicated_shard(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
inp_shape = (8, 2)
arr, inp_data = create_array(
inp_shape, jax.sharding.NamedSharding(global_mesh, P()))
fs = arr._fully_replicated_shard()
self.assertEqual(fs.shape, inp_shape)
self.assertTrue(dispatch.is_single_device_sharding(fs.sharding))
self.assertArraysEqual(fs, inp_data)
self.assertArraysEqual(arr.addressable_data(0), inp_data)
def test_shard_array_to_fully_replicated(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(global_mesh, P())
arr = jnp.arange(16)
self.assertFalse(arr._committed)
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
out = jax.jit(lambda x: x * 2, in_shardings=sharding)(arr)
self.assertTrue(out.sharding.is_fully_replicated)
self.assertArraysEqual(out, arr * 2)
def test_fully_replicated_donated_array_is_deleted(self):
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(global_mesh, P())
arr = jnp.arange(16)
arr_copy = arr.copy()
self.assertFalse(arr._committed)
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
out = jax.jit(lambda x: x * 2, in_shardings=sharding, donate_argnums=0)(arr)
self.assertTrue(out.sharding.is_fully_replicated)
self.assertArraysEqual(out, arr_copy * 2)
self.assertTrue(arr.is_deleted())
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
def test_shards_have_correct_dtype(self, dtype):
x = jnp.ones((), dtype=dtype)
for shard in x.addressable_shards:
self.assertEqual(shard.data.dtype, dtype)
def test_make_array_from_callback_global_array(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P())
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, sharding)
out = jax.make_array_from_callback(np_inp.shape, sharding,
lambda idx: arr[idx])
self.assertArraysEqual(out, arr)
self.assertEqual(out.sharding, sharding)
sharding2 = NamedSharding(mesh, P('x', 'y'))
arr2 = jax.device_put(np_inp, sharding2)
out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
lambda idx: arr2[idx])
self.assertArraysEqual(out2, arr2)
self.assertEqual(out2.sharding, sharding2)
def test_make_array_from_process_data_single_host_data_sharding(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
data = np.ones((256, 512))
s = jax.NamedSharding(mesh, P('x'))
result = jax.make_array_from_process_local_data(s, data)
self.assertArraysEqual(result, data)
self.assertEqual(result.sharding, s)
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@jtu.run_on_devices("gpu")
def test_pinned_host_npy_value_doesnt_cache(self, dtype):
# see https://github.com/jax-ml/jax/issues/26216
d_tensor = jnp.array(0, dtype=dtype)
d_sharding = d_tensor.sharding
h_sharding = d_sharding.with_memory_kind("pinned_host")
h_tensor = jax.device_put(d_tensor, h_sharding)
np.array(h_tensor)
self.assertIsNone(h_tensor._npy_value)
@config.enable_empty_arrays(True)
def test_make_array_from_single_device_arrays_no_dtype_error(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
'If the Array has no addressable shards, `dtype` must be provided via '
'the `dtype` argument to `jax.make_array_from_single_device_arrays`.'):
jax.make_array_from_single_device_arrays((8, 2), s, [])
def test_make_array_from_single_device_arrays_bad_dtype_error(self):
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
shape = (8, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)
with self.assertRaisesRegex(
ValueError,
'If `dtype` is provided to `jax.make_array_from_single_device_arrays`, '
'it must match the dtype of the addressable shards.'):
jax.make_array_from_single_device_arrays(
shape, s, [arr], dtype=jnp.float32)
class ShardingTest(jtu.JaxTestCase):
def test_mesh_pspec_sharding_interface(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
pspec = P('y', 'x')
global_shape = (8, 4)
mp_sharding = jax.sharding.NamedSharding(mesh, pspec)
di_map = mp_sharding.devices_indices_map(global_shape)
hlo_sharding = mp_sharding._to_xla_hlo_sharding(len(global_shape))
device_assignment = mp_sharding._device_assignment
self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
allow_object_dtype=True)
self.assertTrue(hlo_sharding.is_tiled())
self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
self.assertListEqual(hlo_sharding.tile_assignment_devices(),
[0, 2, 4, 6, 1, 3, 5, 7])
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_util_clear_cache(self):
mesh = jtu.create_mesh((1,), ('x',))
s = NamedSharding(mesh, P())
s.devices_indices_map((8,))
jax.clear_caches()
s.devices_indices_map((8,))
c = common_devices_indices_map.cache_info()
self.assertEqual(c.currsize, 1)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y")),
("mesh_x", P("x")),
("mesh_y", P("y")),
("mesh_none_y", P(None, "y")),
("mesh_none_x", P(None, "x")),
("mesh_xy", P(("x", "y"))),
("mesh_fully_replicated", P()),
)
def test_op_sharding_indices(self, pspec):
shape = (8, 4)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, pspec)
ops = jax.sharding.GSPMDSharding(
list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape)))
self.assertDictEqual(
ops.devices_indices_map(shape), mps.devices_indices_map(shape))
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
("mesh_x", P("x"), (2, 4)),
("mesh_y", P("y"), (4, 4)),
("mesh_none_y", P(None, "y"), (8, 2)),
("mesh_none_x", P(None, "x"), (8, 1)),
("mesh_xy", P(("x", "y")), (1, 4)),
("mesh_fully_replicated", P(), (8, 4)),
)
def test_shard_shape(self, pspec, expected_shard_shape):
shape = (8, 4)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, pspec)
self.assertEqual(mps.shard_shape(shape), expected_shard_shape)
def test_uneven_shard_error(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
r"factors: \[4, 2\] should evenly divide the shape\)"):
mps.shard_shape((8, 3))
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pmap_sharding_hash_eq(self):
if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.')
shape = (2, 2)
num_elements = math.prod(shape)
inp_data = np.arange(num_elements).reshape(shape)
out = jax.pmap(lambda x: x)(inp_data)
self.assertIsInstance(out.sharding, jax.sharding.PmapSharding)
# Populate the device_indices_map cache.
_ = out.sharding.devices_indices_map(shape)
cache_info1 = pmap_sharding_devices_indices_map.cache_info()
inp_data2 = np.arange(num_elements, num_elements + num_elements).reshape(shape)
out2 = jax.pmap(lambda x: x)(inp_data2)
# Populate the device_indices_map cache.
_ = out2.sharding.devices_indices_map(shape)
cache_info2 = pmap_sharding_devices_indices_map.cache_info()
self.assertGreater(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
def test_is_compatible_error(self):
shape = (8, 2)
mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))
with self.assertRaisesRegex(
ValueError,
r"Sharding NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\)"
' is only valid for values of rank at least 4, but was applied to a'
' value of rank 2'):
mps.check_compatible_aval(shape)
def test_is_subclass(self):
# array version of api_test.py::APITest::test_is_subclass
self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))
def test_gspmd_sharding_repr(self):
op = xc.OpSharding()
op.type = xc.OpSharding.Type.OTHER
op.tile_assignment_dimensions = [4, 1, 2]
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
op.replicate_on_last_tile_dim = True
s = jax.sharding.GSPMDSharding(jax.devices(), op)
# memory kind also appears in the repr but only for TPU.
self.assertIn(
'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
'last_tile_dim_replicate}', repr(s))
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.REPLICATED
s2 = jax.sharding.GSPMDSharding(jax.devices(), op2)
# memory kind also appears in the repr but only for TPU.
self.assertIn('GSPMDSharding({replicated}', repr(s2))
def test_positional_sharding_fully_replicated(self):
sharding = PositionalSharding(jax.devices())
jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (4, 2), (), False),
("mesh_x", P("x"), (4, 2), (1,), False),
("mesh_y", P("y"), (4, 2), (0,), True),
("mesh_none_y", P(None, "y"), (4, 2), (0,), False),
("mesh_none_x", P(None, "x"), (4, 2), (1,), True),
("mesh_xy", P(("x", "y")), (8, 1), (), False),
("mesh_fully_replicated", P(), (4, 2), None, False),
)
def test_positional_sharding_op_sharding_lowering(
self, pspec, shape, axes, transpose):
value_shape = (8, 4)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, pspec)
devices = jax.local_devices()[:8] # Taking up to 8 devices
devices_sharding = jax.sharding.PositionalSharding(devices)
devices_sharding = devices_sharding.reshape(shape).replicate(axes)
if transpose:
devices_sharding = devices_sharding.T
op1 = mps._to_xla_hlo_sharding(len(value_shape))
op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape))
self.assertEqual(mps.shard_shape(value_shape),
devices_sharding.shard_shape(value_shape))
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
def test_positional_sharding_aval_compatible(self):
if jax.device_count() < 2:
self.skipTest('Requires >=2 devices')
sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count())
x = jax.random.uniform(jax.random.key(42), (256, 20, 1000))
with self.assertRaisesRegex(
ValueError,
'Sharding PositionalSharding.*is only valid for values of rank 2, but'
' was applied to a value of rank 3'):
jax.lax.with_sharding_constraint(x, sharding)
@parameterized.named_parameters(
("2d_mesh_x_y", (4, 2), P("x", "y")),
("2d_mesh_x", (4, 2), P("x")),
("2d_mesh_y", (4, 2), P("y")),
("2d_mesh_none_y", (4, 2), P(None, "y")),
("2d_mesh_none_x", (4, 2), P(None, "x")),
("2d_mesh_xy", (4, 2), P(("x", "y"))),
("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))),
("2d_mesh_x_none", (2, 1), P(('x',), None)),
("2d_mesh_fully_replicated", (4, 2), P()),
("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')),
("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
("3d_mesh_x_none_yz", (2, 2, 2), P('x', None, ('y', 'z'))),
("3d_mesh_none_x_yz", (2, 2, 2), P(None, 'x', ('y', 'z'))),
("3d_mesh_xy_z", (2, 2, 2), P(('x', 'y'), 'z')),
("3d_mesh_xy_none_z", (2, 2, 2), P(('x', 'y'), None, 'z')),
("3d_mesh_x_y_z", (2, 2, 2), P('x', 'y', 'z')),
("3d_mesh_xz_y", (2, 2, 2), P(('x', 'z'), 'y')),
("3d_mesh_xz_none_y", (2, 2, 2), P(('x', 'z'), None, 'y')),
("3d_mesh_y_none_xz", (2, 2, 2), P('y', None, ('x', 'z'))),
("3d_mesh_none_y_xz", (2, 2, 2), P(None, 'y', ('x', 'z'))),
("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)),
)
def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
ndim = len(mesh_shape)
mesh = jtu.create_mesh(
mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
mps = jax.sharding.NamedSharding(mesh, pspec)
original_op_sharding = mps._to_xla_hlo_sharding(ndim)
ps = _op_sharding_to_pos_sharding(original_op_sharding,
mps._device_assignment)
out_op_sharding = ps._to_xla_hlo_sharding(ndim)
self.assertTrue(op_shardings.are_op_shardings_equal(
original_op_sharding, out_op_sharding))
@parameterized.named_parameters(
("2d_mesh_x", (1, 1), P("x", "y")),
("2d_mesh_x_y", (4, 2), P("x", "y")),
("2d_mesh_empty", (2, 1), P()),
("2d_mesh_p_none", (2, 1), P(None)),
("2d_mesh_none_none", (2, 1), P(None, None)),
("2d_mesh_tuple_empty", (2, 1), P((),)),
("2d_mesh_x_none", (2, 1), P(('x',), None)),
("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
)
def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
if len(mesh_shape) == 2:
axis_names = ('x', 'y')
elif len(mesh_shape) == 3:
axis_names = ('x', 'y', 'z')
else:
axis_names = ('x',)
mesh = jtu.create_mesh(mesh_shape, axis_names)
mps = jax.sharding.NamedSharding(mesh, pspec)
shape = (8, 2, 4)
mps_op_sharding = mps._to_xla_hlo_sharding(len(shape))
ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
self.assertEqual(mps.is_fully_replicated, ops_ifr)
ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment)
self.assertEqual(ps.is_fully_replicated,
op_shardings.is_op_sharding_replicated(
ps._to_xla_hlo_sharding(len(shape))))
def test_devices_sharding_respects_init_mesh_shape(self):
value_shape = (8, 4)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
devices_sharding = jax.sharding.PositionalSharding(mesh.devices)
op1 = mps._to_xla_hlo_sharding(len(value_shape))
op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape))
self.assertEqual(mps.shard_shape(value_shape),
devices_sharding.shard_shape(value_shape))
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
def test_pmap_sharding_repr(self):
if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.')
out = jax.pmap(lambda x: x)(jnp.arange(2.))
str(out.sharding) # doesn't crash
repr(out.sharding) # doesn't crash
def test_positional_sharding_repr(self):
if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.')
s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1)
repr(s) # doesn't crash
str(s) # doesn't crash
def test_pspec_tuple(self):
pspec = P('x', 'y', 'z')
self.assertEqual(pspec, ('x', 'y', 'z'))
self.assertEqual(pspec.index('z'), 2)
self.assertEqual(hash(P(None, 'x', 'y', 'z')), hash(P((), 'x', 'y', 'z')))
@parameterized.named_parameters(
('sharded_dim_0', (4, 2), 0),
('sharded_dim_1_0', (4, 2), 1),
('sharded_dim_2', (4, 2, 4), 2),
('sharded_dim_1_1', (2, 4), 1)
)
def test_default_pmap_sharding(self, shape, sharded_dim):
if jax.device_count() < 4:
self.skipTest('Test needs >= 4 devices.')
ps = jax.sharding.PmapSharding.default(shape, sharded_dim)
inp = jnp.arange(math.prod(shape)).reshape(shape)
compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings
self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment)
self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec)
def test_default_pmap_sharding_with_devices(self):
if jax.device_count() < 4:
self.skipTest('Test needs >= 4 devices.')
devs = jax.devices()
new_order = (devs[0], devs[3], devs[2], devs[1])
ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order)
self.assertEqual(ps._device_assignment, new_order)
def test_default_pmap_sharding_replicated(self):
x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32)
x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(x)
ps = jax.sharding.PmapSharding.default(
shape=(8,), sharded_dim=None,
devices=jax.local_devices())
self.assertEqual(x.sharding, ps)
def test_mesh_repr(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
mesh_repr = repr(mesh)
self.assertIn('device_ids', mesh_repr)
self.assertIn('axis_names', mesh_repr)
def test_are_shardings_equivalent(self):
mesh = jtu.create_mesh((1,), ('x'))
mesh2 = jtu.create_mesh((2, 1), ('x', 'y'))
s1 = jax.sharding.NamedSharding(mesh, P('x'))
s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0])
self.assertTrue(s1.is_equivalent_to(s2, 2))
s3 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
s4 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
self.assertTrue(s3.is_equivalent_to(s4, 2))
self.assertFalse(s1.is_equivalent_to(s3, 2))
self.assertFalse(s2.is_equivalent_to(s3, 2))
s5 = jax.sharding.NamedSharding(mesh2, P('x', 'y'))
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.REPLICATED
s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1)
s7 = jax.sharding.GSPMDSharding(jax.devices(), op1)
# The OpSharding is replicated but the Sharding itself are on different
# devices.
self.assertFalse(s6.is_equivalent_to(s7, 2))
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_devices = [0, 1]
op2.tile_assignment_dimensions = [2, 1]
s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)
self.assertTrue(s1.is_equivalent_to(s6, 2))
self.assertTrue(s5.is_equivalent_to(s8, 2))
self.assertFalse(s5.is_equivalent_to(s2, 2))
s9 = jax.sharding.NamedSharding(mesh2, P('y'))
op3 = xc.OpSharding()
op3.type = xc.OpSharding.Type.OTHER
op3.tile_assignment_devices = [0, 1]
op3.tile_assignment_dimensions = [1, 1, 2]
op3.replicate_on_last_tile_dim = True
s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)
self.assertTrue(s9.is_equivalent_to(s10, 2))
def test_devices_indices_map_good_error_message(self):
shape = (1, 2)
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
"Sharding.*implies that array axis 0 is partitioned 2 times, but the "
"dimension size is 1"):
s.devices_indices_map(shape)
def test_scalar_input_wrong_pspec(self):
mesh = jtu.create_mesh((1, ), ('x'))
shape = ()
s = jax.sharding.NamedSharding(mesh, P('x'))
with self.assertRaisesRegex(
ValueError,
r"For scalars the PartitionSpec should be P()"):
s.check_compatible_aval(shape)
def test_mesh_caching_during_construction(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Requires >=2 devices")
mesh1 = jax.sharding.Mesh(jax.devices(), 'x')
mesh2 = jax.sharding.Mesh(jax.devices(), 'x')
self.assertIs(mesh1, mesh2)
def test_mesh_str(self):
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
self.assertEqual(
str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto))"
)
def test_make_array_from_callback_error(self):
mesh_shape = (2, 3)
global_shape = tuple(np.square(mesh_shape))
mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
pspec = P('x', 'y')
sharding = jax.sharding.NamedSharding(mesh, pspec)
n = math.prod(global_shape)
global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
def f(arr):
return array.make_array_from_callback(arr.shape, sharding, lambda i: arr[i])
out = f(global_x)
self.assertEqual(out.shape, global_shape)
msg = "jax.make_array_from_callback cannot be called within a traced context"
with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, msg):
jax.jit(f)(global_x)
def test_make_array_from_single_device_arrays_error(self):
x = jnp.arange(10)
sharding = x.sharding
def f(x):
return jax.make_array_from_single_device_arrays(x.shape, sharding, [x])
msg = "jax.make_array_from_single_device_arrays requires a list of concrete arrays"
with self.assertRaisesRegex(ValueError, msg):
jax.jit(f)(x)
def test_make_array_from_single_device_arrays_nonlist_error(self):
x = jnp.arange(10)
sharding = x.sharding
def f(x):
return jax.make_array_from_single_device_arrays(x.shape, sharding, x)
msg = "jax.make_array_from_single_device_arrays `arrays` argument"
with self.assertRaisesRegex(TypeError, msg):
jax.jit(f)(x)
def test_make_array_from_single_device_arrays_tuple(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (8, 8)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(math.prod(shape)).reshape(shape)
arrays = tuple(
jax.device_put(inp_data[index], d)
for d, index in s.addressable_devices_indices_map(shape).items())
jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash
def test_make_array_from_single_device_arrays_bad_inputs(self):
x = jnp.arange(10)
mesh = jtu.create_mesh((2,), ('x',))
s = jax.sharding.NamedSharding(mesh, P('x'))
x = jax.device_put(x, s)
msg = ("When making an array from single-device arrays the input arrays "
"must have one shard each. An argument array had 2 shard\\(s\\).")
with self.assertRaisesRegex(ValueError, msg):
jax.make_array_from_single_device_arrays(x.shape, s, [x, x])
def test_gspmd_sharding_hash_eq(self):
mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
ns = NamedSharding(mesh, P('x', 'y', 'z'))
x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3))
x2 = GSPMDSharding.get_replicated(mesh._flat_devices_tuple)
self.assertEqual(x1, x2)
self.assertEqual(hash(x1), hash(x2))
def test_device_attr(self):
# For single-device arrays, x.device returns the device
x = jnp.ones((2, 10))
self.assertEqual(x.device, list(x.devices())[0])
# For sharded arrays, x.device returns the sharding
mesh = jtu.create_mesh((2,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x = jax.device_put(x, sharding)
self.assertEqual(x.device, sharding)
def test_to_device(self):
device = jax.devices()[-1]
mesh = jtu.create_mesh((2,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x = jnp.ones((2, 10))
x_device = x.to_device(device)
x_sharding = x.to_device(sharding)
self.assertEqual(x_device.device, device)
self.assertEqual(x_sharding.device, sharding)
def test_mesh_with_axis_name_none(self):
with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
jax.sharding.Mesh(jax.devices(), (None, 'x'))
def test_mesh_axis_types_mismatch(self):
with self.assertRaisesRegex(
ValueError,
'Number of axis names should match the number of axis_types'):
jtu.create_mesh((2, 1), ('x', 'y'),
axis_types=jax.sharding.AxisType.Auto)
with self.assertRaisesRegex(
ValueError,
'Number of axis names should match the number of axis_types'):
jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
axis_types=jax.sharding.AxisType.Auto)
with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
AbstractMesh((2,), ('x',), axis_types=("explicit",))
with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
AbstractMesh((2,), ('x',), axis_types="explicit")
with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"):
AbstractMesh((2, 2), ('x', 'y'),
axis_types=("explicit", AxisType.Explicit))
def test_make_mesh_axis_types(self):
Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual
mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
self.assertEqual(mesh1, mesh2)
mesh = jax.make_mesh((1, 1), ('x', 'y'))
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Auto: ('x', 'y')})
mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
axis_types=(Explicit, Auto, Manual))
self.assertDictEqual(
mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',),
AxisType.Manual: ('z',)})
self.assertEqual(mesh.explicit_axes, ('x',))
self.assertEqual(mesh.auto_axes, ('y',))
self.assertEqual(mesh.manual_axes, ('z',))
mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
axis_types=(Explicit, Explicit, Manual))
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y'),
AxisType.Manual: ('z',)})
mesh = jax.make_mesh((1, 1), ('x', 'y'), axis_types=(Explicit, Explicit))
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y')})
mesh = jax.make_mesh((1,), 'model', axis_types=Manual)
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Manual: ('model',)})
with self.assertRaisesRegex(
ValueError,
'Number of axis names should match the number of axis_types'):
jax.make_mesh((1, 1), ('data', 'model'), axis_types=Explicit)
mesh1 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
axis_types=(Explicit, Auto, Auto, Explicit, Explicit))
mesh2 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
axis_types=(Explicit, Auto, Auto, Explicit, Auto))
self.assertNotEqual(mesh1, mesh2)
self.assertNotEqual(hash(mesh1), hash(mesh2))
def test_memory_kind_with_abstract_mesh(self):
if jaxlib_extension_version < 326:
self.skipTest('Requires jaxlib_extension_version >= 326')
abstract_mesh = AbstractMesh((2,), ('x',))
ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host')
self.assertEqual(ns.memory_kind, 'pinned_host')
ns = NamedSharding(abstract_mesh, P())
self.assertIsNone(ns.memory_kind)
with self.assertRaisesRegex(
ValueError, 'Got invalid memory kind'):
NamedSharding(abstract_mesh, P(), memory_kind='weird_device')
@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyShardingTest(jtu.JaxTestCase):
def test_long_axis_names(self):
mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model'))
s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model'))
sdy_sharding = s._to_sdy_sharding(3)
self.assertEqual(
sdy_sharding,
SdyArraySharding(
mesh.shape_tuple,
[SdyDimSharding(
('sequence', 'data'), True),
SdyDimSharding(('model',), True),
SdyDimSharding([], True)]))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(sdy_sharding.build()),
'#sdy.sharding<mesh<["sequence"=2, "data"=2, "model"=2]>,'
' [{"sequence", "data"}, {"model"}, {}]>',
)
def test_unconstrained(self):
mesh = jtu.create_mesh((8,), ('x',))
s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x'))
sdy_sharding = s._to_sdy_sharding(3)
self.assertEqual(
sdy_sharding,
SdyArraySharding(
mesh.shape_tuple,
[SdyDimSharding([], True),
SdyDimSharding([], False),
SdyDimSharding(('x',), True)]))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(sdy_sharding.build()),
'#sdy.sharding<mesh<["x"=8]>, [{}, {?}, {"x"}]>')
class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected
@parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
@jtu.skip_on_devices("gpu")
def test_random_bits_is_pure_map_1d(self, num_devices):
@jax.jit
def f(x):
bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
32, x.shape)
return bits + x
mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True)
s = jax.sharding.NamedSharding(mesh, P('x'))
n = num_devices ** 2
global_x = jnp.arange(n).astype('uint32')
x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])
# check computation is fully partitioned and without any communication
with jax.threefry_partitionable(True):
unopt_txt = f.lower(x).as_text(dialect='hlo')
opt_txt = f.lower(x).compile().as_text()
self.assertIn( f'[{n}]', unopt_txt)
self.assertNotIn(f'[{n}]', opt_txt)
self.assertNotIn('all-reduce', opt_txt)
self.assertNotIn('collective-permute', opt_txt)
# check against single-device reference
y = f(x)
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref1)
@parameterized.named_parameters(
{"testcase_name": f"_{mesh_shape}_{pspec}",
"mesh_shape": mesh_shape, "pspec": pspec}
for mesh_shape in [(3, 2), (4, 2), (2, 3)]
for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
@jtu.skip_on_devices("gpu")
def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
@jax.jit
def f(x):
bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
32, x.shape)
return bits + x
global_shape = tuple(np.square(mesh_shape))
mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
s = jax.sharding.NamedSharding(mesh, pspec)
n = math.prod(global_shape)
global_x = np.arange(n).astype('uint32').reshape(global_shape)
x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])
# check computation is fully partitioned and without any communication
with jax.threefry_partitionable(True):
unopt_txt = f.lower(x).as_text(dialect='hlo')
opt_txt = f.lower(x).compile().as_text()
global_shape_fmt = ','.join(str(x) for x in global_shape)
self.assertIn( f'[{global_shape_fmt}]', unopt_txt)
self.assertNotIn(f'[{global_shape_fmt}]', opt_txt)
self.assertNotIn('all-reduce', opt_txt)
self.assertNotIn('collective-permute', opt_txt)
# check against single-device reference
y = f(x)
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref1)
def test_empty_mesh_creation(self):
mesh = jax.sharding.Mesh(devices=np.empty((), dtype=object), axis_names=[])
self.assertTrue(mesh.empty)
self.assertEqual(mesh.size, 0)
abstract_mesh = mesh.abstract_mesh
self.assertTrue(abstract_mesh.empty)
self.assertEqual(abstract_mesh.size, 0)
abstract_mesh2 = jax.sharding.AbstractMesh((), ())
self.assertTrue(abstract_mesh2.empty)
self.assertEqual(abstract_mesh2.size, 0)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())