# Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00.
# Original file: 1511 lines, 58 KiB, Python.
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for Array."""
|
|
|
|
import contextlib
|
|
import math
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax._src import core
|
|
from jax._src import dispatch
|
|
from jax._src import op_shardings
|
|
from jax._src import test_util as jtu
|
|
from jax._src import xla_bridge as xb
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.lib.mlir import dialects, ir
|
|
from jax._src.util import safe_zip
|
|
from jax._src.mesh import AxisTypes
|
|
from jax._src.sharding import common_devices_indices_map
|
|
from jax._src.sharding_impls import (
|
|
_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
|
|
NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding,
|
|
SdyArraySharding)
|
|
from jax.experimental.pjit import pjit
|
|
from jax.experimental import multihost_utils
|
|
from jax.sharding import PartitionSpec as P
|
|
from jax._src import array
|
|
from jax._src import prng
|
|
|
|
# Let absl command-line flags configure JAX before any test runs.
jax.config.parse_flags_with_absl()
# Request 8 CPU devices; the tests below build 8-device meshes such as (4, 2).
jtu.request_cpu_devices(8)

try:
  import pytest
except ImportError:
  # pytest is optional: without it, the module-level marker is simply omitted.
  pass
else:
  pytestmark = pytest.mark.multiaccelerator
|
|
|
|
|
|
def create_array(shape, sharding, global_data=None):
  """Builds a sharded ``Array`` whose shards are sliced from host data.

  Args:
    shape: global shape of the array to create.
    sharding: sharding that decides which index each device receives.
    global_data: optional host array to slice shards from. When ``None``, an
      ``np.arange`` of ``math.prod(shape)`` elements reshaped to ``shape`` is
      used.

  Returns:
    A ``(jax_array, global_data)`` tuple, so callers can compare the device
    result against the host source it was built from.
  """
  source = (np.arange(math.prod(shape)).reshape(shape)
            if global_data is None else global_data)

  def _shard_for(index):
    # `index` is the per-device slice tuple chosen by the sharding.
    return source[index]

  sharded = array.make_array_from_callback(shape, sharding, _shard_for)
  return sharded, source
|
|
|
|
|
|
class JaxArrayTest(jtu.JaxTestCase):
|
|
|
|
def test_array_impl_name(self):
|
|
self.assertEqual(array.ArrayImpl.__name__, "ArrayImpl")
|
|
|
|
@parameterized.named_parameters(
|
|
("mesh_x_y", P("x", "y")),
|
|
("mesh_x", P("x")),
|
|
("mesh_y", P("y")),
|
|
("mesh_none_y", P(None, "y")),
|
|
("mesh_xy", P(("x", "y"))),
|
|
("mesh_fully_replicated", P()),
|
|
)
|
|
def test_jax_array_value(self, mesh_axes):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, global_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes))
|
|
for s in arr.addressable_shards:
|
|
self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
|
|
self.assertArraysEqual(s.data, global_data[s.index])
|
|
self.assertArraysEqual(arr._value, global_data)
|
|
if arr._npy_value is not None:
|
|
self.assertArraysEqual(arr._npy_value, global_data)
|
|
|
|
@parameterized.named_parameters(
|
|
("mesh_x_y", P("x", "y"),
|
|
# There are more slices but for convenient purposes, checking for only
|
|
# 2. The indices + shard_shape + replica_id should be unique enough.
|
|
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
|
|
(2, 1),
|
|
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
|
("mesh_x", P("x"),
|
|
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
|
|
(2, 2),
|
|
[0, 1, 0, 1, 0, 1, 0, 1], False),
|
|
("mesh_y", P("y"),
|
|
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
|
|
(4, 2),
|
|
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
|
("mesh_none_y", P(None, "y"),
|
|
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
|
|
(8, 1),
|
|
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
|
("mesh_xy", P(("x", "y")),
|
|
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
|
|
(1, 2),
|
|
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
|
("mesh_fully_replicated", P(),
|
|
((slice(None), slice(None)), (slice(None), slice(None))),
|
|
(8, 2),
|
|
[0, 1, 2, 3, 4, 5, 6, 7], True),
|
|
)
|
|
def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
|
expected_replica_ids, expected_is_fully_replicated):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True)
|
|
global_input_shape = (8, 2)
|
|
s = jax.sharding.NamedSharding(global_mesh, mesh_axes)
|
|
arr, global_input_data = create_array(global_input_shape, s)
|
|
self.assertEqual(arr.ndim, 2)
|
|
self.assertEqual(arr.size, 16)
|
|
self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
|
|
self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
|
|
replica_ids = [i.replica_id for i in arr.addressable_shards]
|
|
self.assertListEqual(replica_ids, expected_replica_ids)
|
|
self.assertListEqual([i.device.id for i in arr.addressable_shards],
|
|
[0, 1, 2, 3, 4, 5, 6, 7])
|
|
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
|
|
for i, s in enumerate(arr.addressable_shards):
|
|
self.assertEqual(s.data.aval,
|
|
core.ShapedArray(expected_shard_shape, s.data.dtype))
|
|
self.assertArraysEqual(s.data, global_input_data[s.index])
|
|
self.assertArraysEqual(s.data, arr.addressable_data(i))
|
|
|
|
for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
|
|
self.assertEqual(g.device, l.device)
|
|
self.assertEqual(g.index, l.index)
|
|
self.assertEqual(g.replica_id, l.replica_id)
|
|
self.assertEqual(g.data.aval, l.data.aval)
|
|
self.assertArraysEqual(g.data, l.data)
|
|
|
|
def test_addressable_data(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
shape = (8, 2)
|
|
s = jax.sharding.NamedSharding(global_mesh, P(None))
|
|
arr, inp_data = create_array(shape, s)
|
|
for i in range(len(arr)):
|
|
self.assertArraysEqual(inp_data, arr.addressable_data(i))
|
|
|
|
def test_array_delete(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, _ = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
arr.delete()
|
|
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
|
arr._check_if_deleted()
|
|
self.assertIsNone(arr._npy_value)
|
|
self.assertIsNone(arr._arrays)
|
|
|
|
def test_single_device_array_usage_after_delete(self):
|
|
x = jnp.array([1, 2, 3])
|
|
x.delete()
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
|
_ = x + 1
|
|
|
|
def test_multi_device_array_usage_after_delete(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
shape = (8, 2)
|
|
arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32),
|
|
jax.sharding.NamedSharding(global_mesh, P('x')))
|
|
arr.delete()
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r'Array has been deleted with shape=int32\[16\].'):
|
|
_ = arr + 1
|
|
|
|
def test_device_put(self):
|
|
numpy_array = np.array([1, 2, 3])
|
|
arr = jax.device_put(numpy_array, jax.devices()[0])
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
self.assertArraysEqual(arr, numpy_array)
|
|
self.assertEqual(arr._committed, True)
|
|
for i in arr.addressable_shards:
|
|
self.assertArraysEqual(i.data, numpy_array)
|
|
self.assertEqual(i.device, jax.devices()[0])
|
|
self.assertEqual(i.index, (slice(None),))
|
|
self.assertEqual(i.replica_id, 0)
|
|
|
|
def test_device_put_array_delete(self):
|
|
arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
|
|
arr.delete()
|
|
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
|
|
arr._check_if_deleted()
|
|
self.assertIsNone(arr._npy_value)
|
|
self.assertIsNone(arr._arrays)
|
|
|
|
def test_array_device_get(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
self.assertArraysEqual(jax.device_get(arr), input_data)
|
|
|
|
def test_repr(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, _ = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
self.assertStartsWith(repr(arr), "Array(")
|
|
|
|
def test_empty_repr(self):
|
|
shape = (0, 5)
|
|
dtype = 'float32'
|
|
x = jnp.empty(shape, dtype)
|
|
self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})")
|
|
|
|
def test_jnp_array(self):
|
|
arr = jnp.array([1, 2, 3])
|
|
self.assertIsInstance(arr, array.ArrayImpl)
|
|
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
|
|
self.assertEqual(arr._committed, False)
|
|
self.assertFalse(arr.weak_type)
|
|
|
|
def test_jnp_array_jit_add(self):
|
|
a = jnp.array([1, 2, 3])
|
|
b = jnp.array([4, 5, 6])
|
|
arr = jax.jit(lambda x, y: x + y)(a, b)
|
|
self.assertIsInstance(arr, array.ArrayImpl)
|
|
self.assertArraysEqual(arr, np.array([5, 7, 9]))
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
|
|
def test_jnp_array_jnp_add(self):
|
|
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
|
|
self.assertIsInstance(arr, array.ArrayImpl)
|
|
self.assertArraysEqual(arr, np.array([5, 7, 9]))
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
|
|
def test_jnp_array_normal_add(self):
|
|
a = jnp.array([1, 2, 3])
|
|
b = jnp.array([4, 5, 6])
|
|
arr = a + b
|
|
self.assertIsInstance(arr, array.ArrayImpl)
|
|
self.assertArraysEqual(arr, np.array([5, 7, 9]))
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
|
|
def test_array_sharded_astype(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
arr_float32 = arr.astype(jnp.float32)
|
|
self.assertEqual(arr_float32.dtype, np.float32)
|
|
self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
|
|
self.assertLen(arr_float32.addressable_shards, 8)
|
|
for i in arr_float32.addressable_shards:
|
|
self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))
|
|
|
|
def test_jnp_array_astype(self):
|
|
arr = jnp.array([1, 2, 3])
|
|
arr_float32 = arr.astype(jnp.float32)
|
|
self.assertEqual(arr_float32.dtype, np.float32)
|
|
self.assertArraysEqual(arr_float32, arr.astype(np.float32))
|
|
|
|
def test_array_delete_idempotent(self):
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
|
arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x')))
|
|
|
|
arr.delete()
|
|
self.assertTrue(arr.is_deleted())
|
|
|
|
arr.delete() # Run delete again to check if it's idempotent.
|
|
self.assertTrue(arr.is_deleted())
|
|
|
|
def test_sharded_add(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
a, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
b, _ = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x')))
|
|
out = a + b
|
|
expected = input_data + input_data
|
|
self.assertArraysEqual(out, expected)
|
|
self.assertLen(out.addressable_shards, 8)
|
|
for i in out.addressable_shards:
|
|
self.assertArraysEqual(i.data, expected[i.index])
|
|
|
|
def test_sharded_zeros_like(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
a, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
out = jnp.zeros_like(a)
|
|
expected = jnp.zeros(input_data.shape, dtype=a.dtype)
|
|
self.assertArraysEqual(out, expected)
|
|
self.assertLen(out.addressable_shards, 8)
|
|
for i in out.addressable_shards:
|
|
self.assertArraysEqual(i.data, expected[i.index])
|
|
|
|
def test_zeros_like(self):
|
|
a = jnp.array([1, 2, 3], dtype=np.int32)
|
|
out = jnp.zeros_like(a)
|
|
expected = np.zeros(a.shape, dtype=np.int32)
|
|
self.assertArraysEqual(out, expected)
|
|
self.assertTrue(dispatch.is_single_device_sharding(out.sharding))
|
|
|
|
def test_wrong_num_arrays(self):
|
|
if jax.device_count() < 4:
|
|
self.skipTest('Requires more than 4 devices')
|
|
shape = (8, 2)
|
|
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
|
devices = jax.local_devices()[:2] # Taking up to 2 devices
|
|
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
|
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
|
di_map = s.devices_indices_map(shape)
|
|
bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices]
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
r'Expected 2 per-device arrays \(this is how many devices are addressable '
|
|
r'by the sharding\), but got 1'):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:1], committed=True)
|
|
|
|
for buf, d in zip(list(bufs), jax.local_devices()[2:4]):
|
|
bufs.append(jax.device_put(buf, d))
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
r'Expected 2 per-device arrays \(this is how many devices are addressable '
|
|
r'by the sharding\), but got 4'):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
|
|
|
def test_arrays_not_in_device_assignment(self):
|
|
if jax.device_count() < 4:
|
|
self.skipTest('Requires more than 4 devices')
|
|
shape = (8, 2)
|
|
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
|
# sharding device ids = {0, 1}
|
|
s = jax.sharding.NamedSharding(mesh, P('x'))
|
|
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
|
# _arrays device ids = {2, 3}
|
|
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"Addressable devices and per-device arrays devices do not match. "
|
|
"Sharding contains devices {0, 1} that are not present in per-device "
|
|
"arrays. Per-device arrays contain devices {2, 3} that are not present "
|
|
"in the sharding."):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
|
|
|
def test_different_devices_in_arrays_than_sharding(self):
|
|
if jax.device_count() < 3:
|
|
self.skipTest('Requires more than 3 devices')
|
|
shape = (8, 2)
|
|
mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
|
|
# sharding device ids = {1, 2}
|
|
s = jax.sharding.NamedSharding(mesh, P('x'))
|
|
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
|
# _arrays device ids = {0, 1}
|
|
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"Addressable devices and per-device arrays devices do not match. "
|
|
r"Sharding contains devices \{2\} that are not present in per-device "
|
|
r"arrays. Per-device arrays contain devices \{0\} that are not present "
|
|
"in the sharding."):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
|
|
|
def test_duplicated_devices_in_arrays(self):
|
|
if xc._version <= 274:
|
|
self.skipTest('Test requires jaxlib version 275')
|
|
shape = (8, 2)
|
|
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
|
# Sharding device ids = {0, 1}
|
|
s = jax.sharding.NamedSharding(mesh, P('x'))
|
|
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
|
# _arrays device ids = {0, 0}
|
|
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
'When making an array from single-device arrays, the input arrays must'
|
|
' be from distinct devices'):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
|
|
|
@parameterized.named_parameters(
|
|
("mesh_x_y", P("x", "y"), (2, 2)),
|
|
("mesh_x", P("x"), (2, 4)),
|
|
("mesh_y", P("y"), (4, 4)),
|
|
("mesh_none_y", P(None, "y"), (8, 2)),
|
|
("mesh_none_x", P(None, "x"), (8, 1)),
|
|
("mesh_xy", P(("x", "y")), (1, 4)),
|
|
("mesh_replicated", P(()), (8, 4)),
|
|
)
|
|
def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
|
|
shape = (8, 4)
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
mps = jax.sharding.NamedSharding(mesh, pspec)
|
|
inp_data = np.arange(5)
|
|
|
|
str_expected_shard_shape = str(expected_shard_shape).replace(
|
|
r"(", r"\(").replace(r")", r"\)")
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
f"Expected shard shape {str_expected_shard_shape} doesn't match the "
|
|
"single device array shape"):
|
|
array.make_array_from_callback(shape, mps, lambda idx: inp_data)
|
|
|
|
def test_mismatch_dtype(self):
|
|
shape = (8, 2)
|
|
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
|
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
|
inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
|
|
indices = s.devices_indices_map(shape)
|
|
bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"Input buffers to `Array` must have matching dtypes. "
|
|
"Got int32, expected float32"):
|
|
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
|
|
|
def test_array_iter_pmap_sharding(self):
|
|
if jax.device_count() < 2:
|
|
self.skipTest('Test requires >= 2 devices.')
|
|
|
|
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
|
|
y = jax.pmap(jnp.sin)(x)
|
|
self.assertArraysEqual([list(a.devices())[0] for a in y],
|
|
y.sharding._device_assignment,
|
|
allow_object_dtype=True)
|
|
|
|
sin_x = iter(np.sin(x))
|
|
for i, j in zip(iter(y), sin_x):
|
|
self.assertIsInstance(i, array.ArrayImpl)
|
|
self.assertArraysAllClose(i, j)
|
|
|
|
def test_array_iter_pmap_sharding_last_dim_sharded(self):
|
|
if jax.device_count() < 2:
|
|
self.skipTest('Test requires >= 2 devices.')
|
|
|
|
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
|
|
y = jax.pmap(jnp.sin, out_axes=1)(x)
|
|
|
|
for i, j in zip(iter(y), iter(np.sin(x).T)):
|
|
self.assertArraysAllClose(i, j)
|
|
|
|
def test_array_iter_mesh_pspec_sharding_multi_device(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
|
|
for i, j in zip(iter(arr), iter(input_data)):
|
|
self.assertIsInstance(i, array.ArrayImpl)
|
|
self.assertArraysEqual(i, j)
|
|
|
|
def test_array_iter_replicated_multi_device(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
|
|
|
|
for i, j in zip(iter(arr), iter(input_data)):
|
|
self.assertIsInstance(i, array.ArrayImpl)
|
|
self.assertArraysEqual(i, j)
|
|
self.assertLen(i.sharding.device_set, 8)
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
arr.sharding._to_xla_hlo_sharding(arr.ndim),
|
|
i.sharding._to_xla_hlo_sharding(i.ndim)))
|
|
|
|
def test_array_getitem_mesh_pspec_sharding_multi_device(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
|
|
s = arr[2:4, 0:1]
|
|
self.assertIsInstance(s, array.ArrayImpl)
|
|
self.assertArraysEqual(s, input_data[2:4, 0:1])
|
|
|
|
p = arr[:2]
|
|
self.assertIsInstance(p, array.ArrayImpl)
|
|
self.assertArraysEqual(p, input_data[:2])
|
|
|
|
def test_array_getitem_compile_multi_device_sharding(self):
|
|
def _check(out, inp, shard_shape):
|
|
self.assertArraysEqual(out, inp)
|
|
self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape)
|
|
self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding)
|
|
|
|
global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
|
input_shape = (4, 4, 2)
|
|
arr, np_inp = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z')))
|
|
|
|
_check(arr[:, -1, :], np_inp[:, -1, :], (2, 1))
|
|
_check(arr[0, 0, 0], np_inp[0, 0, 0], ())
|
|
_check(arr[-1, -1, :], np_inp[-1, -1, :], (1,))
|
|
_check(arr[:, 1, 0], np_inp[:, 1, 0], (2,))
|
|
_check(arr[:, :, :], np_inp[:, :, :], (2, 2, 1))
|
|
_check(arr[3, :, :], np_inp[3, :, :], (2, 1))
|
|
_check(arr[-1, -1, -1], np_inp[-1, -1, -1], ())
|
|
_check(arr[2, -1, :], np_inp[2, -1, :], (1,))
|
|
_check(arr[2, 3, 1], np_inp[2, 3, 1], ())
|
|
_check(arr[-1], np_inp[-1], (2, 1))
|
|
_check(arr[:], np_inp[:], (2, 2, 1))
|
|
_check(arr[np.array(0), :, :], np_inp[np.array(0), :, :], (2, 1))
|
|
_check(arr[jnp.array(0), :, :], np_inp[jnp.array(0), :, :], (2, 1))
|
|
_check(arr[0, :2, 1], np_inp[0, :2, 1], (2,))
|
|
_check(arr[:, 1::2], np_inp[:, 1::2], (2, 2, 1))
|
|
_check(arr[:, -1:, :], np_inp[:, -1:, :], (2, 1, 1))
|
|
_check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
|
|
_check(arr[:4], np_inp[:4], (2, 2, 1))
|
|
_check(arr[::-1], np_inp[::-1], (2, 2, 1))
|
|
_check(arr[1], np_inp[1], (2, 1))
|
|
|
|
def test_array_getitem_replicated_multi_device(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
|
|
|
|
s = arr[2:4, 0:1]
|
|
self.assertIsInstance(s, array.ArrayImpl)
|
|
self.assertArraysEqual(s, np.array([[4], [6]]))
|
|
self.assertLen(s.sharding.device_set, 8)
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
arr.sharding._to_xla_hlo_sharding(arr.ndim),
|
|
s.sharding._to_xla_hlo_sharding(s.ndim)))
|
|
|
|
p = arr[:2]
|
|
self.assertIsInstance(p, array.ArrayImpl)
|
|
self.assertArraysEqual(p, input_data[:2])
|
|
self.assertLen(s.sharding.device_set, 8)
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
arr.sharding._to_xla_hlo_sharding(arr.ndim),
|
|
s.sharding._to_xla_hlo_sharding(s.ndim)))
|
|
|
|
def test_array_iter_mesh_pspec_sharding_single_device(self):
|
|
if jax.device_count() < 2:
|
|
self.skipTest('Test requires >= 2 devices.')
|
|
|
|
single_dev = jax.devices()[1:2]
|
|
mesh = jax.sharding.Mesh(np.array(single_dev), ('x'))
|
|
input_shape = (8, 2)
|
|
arr, input_data = create_array(
|
|
input_shape, jax.sharding.NamedSharding(mesh, P('x')))
|
|
|
|
for i, j in zip(arr, iter(input_data)):
|
|
self.assertArraysEqual(i, j)
|
|
self.assertEqual(i.devices(), {single_dev[0]})
|
|
|
|
def test_array_shards_committed(self):
|
|
if jax.device_count() < 2:
|
|
self.skipTest('Test requires >= 2 devices.')
|
|
|
|
x = jnp.array([1, 2, 3])
|
|
for s in x.addressable_shards:
|
|
self.assertEqual(s.data._committed, x._committed)
|
|
self.assertFalse(s.data._committed)
|
|
|
|
y = jax.device_put(x, jax.devices()[1])
|
|
for s in y.addressable_shards:
|
|
self.assertEqual(s.data._committed, y._committed)
|
|
self.assertTrue(s.data._committed)
|
|
|
|
def test_array_jnp_array_copy_multi_device(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, _ = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
|
|
c_arr = jnp.array(arr, copy=True)
|
|
self.assertArraysEqual(arr, c_arr)
|
|
self.assertEqual(arr._committed, c_arr._committed)
|
|
for a, c in safe_zip(arr.addressable_shards, c_arr.addressable_shards):
|
|
self.assertArraysEqual(a.data, c.data)
|
|
self.assertEqual(a.index, c.index)
|
|
self.assertEqual(a.replica_id, c.replica_id)
|
|
self.assertEqual(a.device, c.device)
|
|
self.assertNotEqual(a.data.unsafe_buffer_pointer(),
|
|
c.data.unsafe_buffer_pointer())
|
|
|
|
def test_array_addressable_shards(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
input_shape = (8, 2)
|
|
arr, _ = create_array(
|
|
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
|
|
for a in arr.addressable_shards:
|
|
self.assertIsInstance(a.data, array.ArrayImpl)
|
|
|
|
x = jnp.array([1, 2, 3])
|
|
self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)
|
|
|
|
def test_array_not_hashable(self):
|
|
x = jnp.arange(4)
|
|
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
|
hash(x)
|
|
|
|
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
|
jax.jit(hash)(x)
|
|
|
|
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
|
jax.vmap(hash)(x)
|
|
|
|
def test_shape_dtype_struct_sharding_jit(self):
|
|
mesh = jtu.create_mesh((8,), ('x'))
|
|
s = jax.sharding.NamedSharding(mesh, P('x'))
|
|
|
|
x_dummy = jax.ShapeDtypeStruct(
|
|
shape=(16,),
|
|
dtype=jnp.dtype('float32'),
|
|
sharding=s)
|
|
|
|
def f(x):
|
|
return x * 2
|
|
|
|
c = jax.jit(f).lower(x_dummy).compile()
|
|
input_shardings, output_shardings = c.input_shardings, c.output_shardings
|
|
self.assertLen(input_shardings, 2)
|
|
self.assertEqual(input_shardings[1], {})
|
|
self.assertEqual(input_shardings[1], {})
|
|
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
|
|
s._to_xla_hlo_sharding(x_dummy.ndim)))
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
|
|
s._to_xla_hlo_sharding(x_dummy.ndim)))
|
|
|
|
def test_shape_dtype_struct_sharding_pjit(self):
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
|
|
|
def f(x):
|
|
return x * 2.
|
|
|
|
x_dummy = jax.ShapeDtypeStruct(
|
|
shape=(8, 2),
|
|
dtype=jnp.dtype('float32'),
|
|
sharding=s)
|
|
|
|
c = pjit(f).lower(x_dummy).compile()
|
|
input_shardings, output_shardings = c.input_shardings, c.output_shardings
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
|
|
s._to_xla_hlo_sharding(x_dummy.ndim)))
|
|
self.assertTrue(
|
|
op_shardings.are_op_shardings_equal(
|
|
output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
|
|
s._to_xla_hlo_sharding(x_dummy.ndim)))
|
|
|
|
# TODO(skyewm): remove this test when we can remove the workaround manual
|
|
# defragment API
|
|
@jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU
|
|
def test_defragment(self):
|
|
if xb.using_pjrt_c_api():
|
|
self.skipTest("Manual defragment not exposed via PJRT C API")
|
|
|
|
# Create a few arrays
|
|
global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',))
|
|
shape = (8, 2)
|
|
mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',))
|
|
arr1, data = create_array(shape, mpsharding)
|
|
arr2, _ = create_array(shape, mpsharding, data)
|
|
arr3, _ = create_array(shape, mpsharding, data)
|
|
|
|
# Delete one of them
|
|
arr2.delete()
|
|
|
|
# Defragment
|
|
xb.get_backend().defragment()
|
|
|
|
# Sanity check remaining arrays
|
|
self.assertArraysEqual(arr1, data)
|
|
self.assertArraysEqual(arr1 + arr3, data * 2)
|
|
|
|
# TODO(skyewm): check that defragmentation actually happened. I originally
|
|
# thought to do this with unsafe_buffer_pointer(), but that's not always the
|
|
# device memory address. Other ideas include causing enough fragmentation to
|
|
# OOM, and exposing allocator stats in Python.
|
|
|
|
def test_on_device_size_in_bytes(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
a, _ = create_array(
|
|
(8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
|
shard_size = a.addressable_shards[0].data.on_device_size_in_bytes()
|
|
self.assertGreaterEqual(shard_size, 4 * 2)
|
|
self.assertEqual(shard_size * len(a.global_shards),
|
|
a.on_device_size_in_bytes())
|
|
|
|
def test_array_is_ready(self):
|
|
x = jax.device_put(jnp.arange(8.), jax.devices()[0])
|
|
x.is_ready() # doesn't crash
|
|
|
|
def test_process_allgather_single_host(self):
|
|
x = jnp.arange(8.)
|
|
out = multihost_utils.process_allgather(x, tiled=True)
|
|
self.assertEqual(out.shape, x.shape)
|
|
self.assertArraysEqual(out, x)
|
|
|
|
out = multihost_utils.process_allgather(x)
|
|
self.assertEqual(out.shape, (1, x.shape[0]))
|
|
self.assertArraysEqual(out, np.expand_dims(x, axis=0))
|
|
|
|
@jtu.sample_product(
|
|
dtype=jtu.dtypes.all,
|
|
shape=[(), (10), (2, 3)],
|
|
)
|
|
@jtu.run_on_devices("cpu")
|
|
def test_buffer_protocol(self, dtype, shape):
|
|
rng = jtu.rand_default(self.rng())
|
|
x = rng(shape, dtype)
|
|
y = jax.device_put(x)
|
|
if dtype == jax.dtypes.bfloat16:
|
|
with self.assertRaisesRegex(
|
|
BufferError,
|
|
'Buffers of type BF16 are not supported by the Python buffer '
|
|
'protocol.'
|
|
):
|
|
memoryview(y)
|
|
return
|
|
|
|
x_bytes = memoryview(x).tobytes()
|
|
y_bytes = memoryview(y).tobytes()
|
|
self.assertEqual(x_bytes, y_bytes)
|
|
|
|
@jtu.run_on_devices("cpu")
|
|
def test_buffer_protocol_deletion(self):
|
|
rng = jtu.rand_default(self.rng())
|
|
x = rng((3, 4), np.float32)
|
|
y = jax.device_put(x)
|
|
x_bytes = memoryview(x).tobytes()
|
|
y_view = memoryview(y)
|
|
# The array does not actually get deleted until any external reference is
|
|
# dropped. Arguably we should make calling delete() in these circumstances
|
|
# return an error instead, but that would be a behavior change for existing
|
|
# users.
|
|
y.delete()
|
|
y_bytes = y_view.tobytes()
|
|
self.assertEqual(x_bytes, y_bytes)
|
|
|
|
def test_array_copy_to_host_async(self):
|
|
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
|
x = pjit(lambda: jnp.arange(8.),
|
|
out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))()
|
|
self.assertLen(x.sharding.device_set, 4)
|
|
x.copy_to_host_async() # doesn't crash
|
|
self.assertArraysEqual(np.arange(8.), x)
|
|
|
|
def test_array_fully_replicated_shard(self):
|
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
inp_shape = (8, 2)
|
|
arr, inp_data = create_array(
|
|
inp_shape, jax.sharding.NamedSharding(global_mesh, P()))
|
|
fs = arr._fully_replicated_shard()
|
|
self.assertEqual(fs.shape, inp_shape)
|
|
self.assertTrue(dispatch.is_single_device_sharding(fs.sharding))
|
|
self.assertArraysEqual(fs, inp_data)
|
|
self.assertArraysEqual(arr.addressable_data(0), inp_data)
|
|
|
|
def test_shard_array_to_fully_replicated(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
sharding = jax.sharding.NamedSharding(global_mesh, P())
|
|
arr = jnp.arange(16)
|
|
self.assertFalse(arr._committed)
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
out = jax.jit(lambda x: x * 2, in_shardings=sharding)(arr)
|
|
self.assertTrue(out.sharding.is_fully_replicated)
|
|
self.assertArraysEqual(out, arr * 2)
|
|
|
|
def test_fully_replicated_donated_array_is_deleted(self):
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
sharding = jax.sharding.NamedSharding(global_mesh, P())
|
|
arr = jnp.arange(16)
|
|
arr_copy = arr.copy()
|
|
self.assertFalse(arr._committed)
|
|
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
|
out = jax.jit(lambda x: x * 2, in_shardings=sharding, donate_argnums=0)(arr)
|
|
self.assertTrue(out.sharding.is_fully_replicated)
|
|
self.assertArraysEqual(out, arr_copy * 2)
|
|
self.assertTrue(arr.is_deleted())
|
|
|
|
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
|
|
def test_shards_have_correct_dtype(self, dtype):
|
|
x = jnp.ones((), dtype=dtype)
|
|
for shard in x.addressable_shards:
|
|
self.assertEqual(shard.data.dtype, dtype)
|
|
|
|
def test_make_array_from_callback_global_array(self):
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
sharding = jax.sharding.NamedSharding(mesh, P())
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
arr = jax.device_put(np_inp, sharding)
|
|
|
|
out = jax.make_array_from_callback(np_inp.shape, sharding,
|
|
lambda idx: arr[idx])
|
|
self.assertArraysEqual(out, arr)
|
|
self.assertEqual(out.sharding, sharding)
|
|
|
|
sharding2 = NamedSharding(mesh, P('x', 'y'))
|
|
arr2 = jax.device_put(np_inp, sharding2)
|
|
out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
|
|
lambda idx: arr2[idx])
|
|
self.assertArraysEqual(out2, arr2)
|
|
self.assertEqual(out2.sharding, sharding2)
|
|
|
|
def test_make_array_from_process_data_single_host_data_sharding(self):
  """In a single process, the local data becomes the whole global array."""
  local_data = np.ones((256, 512))
  sharding = jax.NamedSharding(jtu.create_mesh((2, 1), ('x', 'y')), P('x'))
  global_arr = jax.make_array_from_process_local_data(sharding, local_data)
  self.assertArraysEqual(global_arr, local_data)
  self.assertEqual(global_arr.sharding, sharding)
|
|
|
|
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@jtu.run_on_devices("gpu")
def test_pinned_host_npy_value_doesnt_cache(self, dtype):
  """Converting a pinned_host array to NumPy leaves `_npy_value` unset."""
  # see https://github.com/jax-ml/jax/issues/26216
  on_device = jnp.array(0, dtype=dtype)
  host_sharding = on_device.sharding.with_memory_kind("pinned_host")
  on_host = jax.device_put(on_device, host_sharding)
  np.array(on_host)  # go through the NumPy conversion path once
  self.assertIsNone(on_host._npy_value)
|
|
|
|
|
|
class ShardingTest(jtu.JaxTestCase):
  """Tests for NamedSharding, GSPMDSharding, PositionalSharding, PmapSharding,
  PartitionSpec and Mesh construction/representation."""

  def test_mesh_pspec_sharding_interface(self):
    """NamedSharding exposes consistent index maps, HLO sharding and devices."""
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    pspec = P('y', 'x')
    global_shape = (8, 4)
    mp_sharding = jax.sharding.NamedSharding(mesh, pspec)
    di_map = mp_sharding.devices_indices_map(global_shape)
    hlo_sharding = mp_sharding._to_xla_hlo_sharding(len(global_shape))
    device_assignment = mp_sharding._device_assignment

    self.assertEqual(di_map[mesh.devices.flat[0]], (slice(0, 4), slice(0, 1)))
    self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
                           allow_object_dtype=True)
    self.assertTrue(hlo_sharding.is_tiled())
    self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
    self.assertListEqual(hlo_sharding.tile_assignment_devices(),
                         [0, 2, 4, 6, 1, 3, 5, 7])

  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
  def test_util_clear_cache(self):
    """jax.clear_caches() also empties the devices_indices_map cache."""
    mesh = jtu.create_mesh((1,), ('x',))
    s = NamedSharding(mesh, P())
    s.devices_indices_map((8,))
    jax.clear_caches()
    s.devices_indices_map((8,))
    c = common_devices_indices_map.cache_info()
    # Only the lookup made after clearing should remain cached.
    self.assertEqual(c.currsize, 1)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_none_x", P(None, "x")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_op_sharding_indices(self, pspec):
    """A GSPMDSharding built from a NamedSharding's HLO has equal indices."""
    shape = (8, 4)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    ops = jax.sharding.GSPMDSharding(
        list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape)))
    self.assertDictEqual(
        ops.devices_indices_map(shape), mps.devices_indices_map(shape))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
      ("mesh_fully_replicated", P(), (8, 4)),
  )
  def test_shard_shape(self, pspec, expected_shard_shape):
    """shard_shape divides each array dim by the mesh axes mapped onto it."""
    shape = (8, 4)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    self.assertEqual(mps.shard_shape(shape), expected_shard_shape)

  def test_uneven_shard_error(self):
    """shard_shape rejects a dimension not divisible by its partitioning."""
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
        r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
        r"factors: \[4, 2\] should evenly divide the shape\)"):
      mps.shard_shape((8, 3))

  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
  def test_pmap_sharding_hash_eq(self):
    """PmapShardings of two same-shaped pmap outputs share a cache entry.

    The second devices_indices_map lookup must be a cache hit, not a miss.
    """
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')

    shape = (2, 2)
    num_elements = math.prod(shape)
    inp_data = np.arange(num_elements).reshape(shape)
    out = jax.pmap(lambda x: x)(inp_data)
    self.assertIsInstance(out.sharding, jax.sharding.PmapSharding)
    # Populate the device_indices_map cache.
    _ = out.sharding.devices_indices_map(shape)
    cache_info1 = pmap_sharding_devices_indices_map.cache_info()

    inp_data2 = np.arange(num_elements, num_elements + num_elements).reshape(shape)
    out2 = jax.pmap(lambda x: x)(inp_data2)
    # Populate the device_indices_map cache.
    _ = out2.sharding.devices_indices_map(shape)
    cache_info2 = pmap_sharding_devices_indices_map.cache_info()

    self.assertGreater(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

  def test_is_compatible_error(self):
    """A PartitionSpec longer than the value's rank is rejected."""
    shape = (8, 2)
    mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
    mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))

    with self.assertRaisesRegex(
        ValueError,
        r"Sharding NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\)"
        ' is only valid for values of rank at least 4, but was applied to a'
        ' value of rank 2'):
      mps.check_compatible_aval(shape)

  def test_is_subclass(self):
    """ArrayImpl is a jax.Array but not a np.ndarray."""
    # array version of api_test.py::APITest::test_is_subclass
    self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
    self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))

  def test_gspmd_sharding_repr(self):
    """GSPMDSharding's repr embeds the HLO sharding text."""
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [4, 1, 2]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    op.replicate_on_last_tile_dim = True
    s = jax.sharding.GSPMDSharding(jax.devices(), op)
    # memory kind also appears in the repr but only for TPU.
    self.assertIn(
        'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
        'last_tile_dim_replicate}', repr(s))

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.REPLICATED
    s2 = jax.sharding.GSPMDSharding(jax.devices(), op2)
    # memory kind also appears in the repr but only for TPU.
    self.assertIn('GSPMDSharding({replicated}', repr(s2))

  def test_positional_sharding_fully_replicated(self):
    """device_put with a replicated PositionalSharding succeeds."""
    sharding = PositionalSharding(jax.devices())
    jax.device_put(jnp.array(1), sharding.replicate())  # doesn't crash

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (4, 2), (), False),
      ("mesh_x", P("x"), (4, 2), (1,), False),
      ("mesh_y", P("y"), (4, 2), (0,), True),
      ("mesh_none_y", P(None, "y"), (4, 2), (0,), False),
      ("mesh_none_x", P(None, "x"), (4, 2), (1,), True),
      ("mesh_xy", P(("x", "y")), (8, 1), (), False),
      ("mesh_fully_replicated", P(), (4, 2), None, False),
  )
  def test_positional_sharding_op_sharding_lowering(
      self, pspec, shape, axes, transpose):
    """A PositionalSharding shaped to mirror `pspec` lowers identically.

    The devices are reshaped to `shape`, replicated over `axes` and optionally
    transposed; shard shape and HLO sharding must match the NamedSharding.
    """
    value_shape = (8, 4)

    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    devices = jax.local_devices()[:8]  # Taking up to 8 devices

    devices_sharding = jax.sharding.PositionalSharding(devices)
    devices_sharding = devices_sharding.reshape(shape).replicate(axes)
    if transpose:
      devices_sharding = devices_sharding.T

    op1 = mps._to_xla_hlo_sharding(len(value_shape))
    op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape))

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

  def test_positional_sharding_aval_compatible(self):
    """A rank-2 PositionalSharding is rejected for a rank-3 value."""
    if jax.device_count() < 2:
      self.skipTest('Requires >=2 devices')
    sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count())
    x = jax.random.uniform(jax.random.key(42), (256, 20, 1000))
    with self.assertRaisesRegex(
        ValueError,
        'Sharding PositionalSharding.*is only valid for values of rank 2, but'
        ' was applied to a value of rank 3'):
      jax.lax.with_sharding_constraint(x, sharding)

  @parameterized.named_parameters(
      ("2d_mesh_x_y", (4, 2), P("x", "y")),
      ("2d_mesh_x", (4, 2), P("x")),
      ("2d_mesh_y", (4, 2), P("y")),
      ("2d_mesh_none_y", (4, 2), P(None, "y")),
      ("2d_mesh_none_x", (4, 2), P(None, "x")),
      ("2d_mesh_xy", (4, 2), P(("x", "y"))),
      ("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))),
      ("2d_mesh_x_none", (2, 1), P(('x',), None)),
      ("2d_mesh_fully_replicated", (4, 2), P()),
      ("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')),
      ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
      ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
      ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
      ("3d_mesh_x_none_yz", (2, 2, 2), P('x', None, ('y', 'z'))),
      ("3d_mesh_none_x_yz", (2, 2, 2), P(None, 'x', ('y', 'z'))),
      ("3d_mesh_xy_z", (2, 2, 2), P(('x', 'y'), 'z')),
      ("3d_mesh_xy_none_z", (2, 2, 2), P(('x', 'y'), None, 'z')),
      ("3d_mesh_x_y_z", (2, 2, 2), P('x', 'y', 'z')),
      ("3d_mesh_xz_y", (2, 2, 2), P(('x', 'z'), 'y')),
      ("3d_mesh_xz_none_y", (2, 2, 2), P(('x', 'z'), None, 'y')),
      ("3d_mesh_y_none_xz", (2, 2, 2), P('y', None, ('x', 'z'))),
      ("3d_mesh_none_y_xz", (2, 2, 2), P(None, 'y', ('x', 'z'))),
      ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
      ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
      ("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)),
  )
  def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
    """HLO sharding -> PositionalSharding -> HLO sharding round-trips."""
    ndim = len(mesh_shape)
    mesh = jtu.create_mesh(
        mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    original_op_sharding = mps._to_xla_hlo_sharding(ndim)
    ps = _op_sharding_to_pos_sharding(original_op_sharding,
                                      mps._device_assignment)
    out_op_sharding = ps._to_xla_hlo_sharding(ndim)
    self.assertTrue(op_shardings.are_op_shardings_equal(
        original_op_sharding, out_op_sharding))

  @parameterized.named_parameters(
      ("2d_mesh_x", (1, 1), P("x", "y")),
      ("2d_mesh_x_y", (4, 2), P("x", "y")),
      ("2d_mesh_empty", (2, 1), P()),
      ("2d_mesh_p_none", (2, 1), P(None)),
      ("2d_mesh_none_none", (2, 1), P(None, None)),
      ("2d_mesh_tuple_empty", (2, 1), P((),)),
      ("2d_mesh_x_none", (2, 1), P(('x',), None)),
      ("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
      ("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
      ("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
      ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
      ("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
      ("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
  )
  def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
    """is_fully_replicated agrees with the lowered HLO sharding.

    Checked for the NamedSharding and for the PositionalSharding derived
    from its HLO sharding.
    """
    if len(mesh_shape) == 2:
      axis_names = ('x', 'y')
    elif len(mesh_shape) == 3:
      axis_names = ('x', 'y', 'z')
    else:
      # Fallback; none of the parameters above use a 1D mesh.
      axis_names = ('x',)
    mesh = jtu.create_mesh(mesh_shape, axis_names)
    mps = jax.sharding.NamedSharding(mesh, pspec)
    shape = (8, 2, 4)
    mps_op_sharding = mps._to_xla_hlo_sharding(len(shape))
    ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
    self.assertEqual(mps.is_fully_replicated, ops_ifr)

    ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment)
    self.assertEqual(ps.is_fully_replicated,
                     op_shardings.is_op_sharding_replicated(
                         ps._to_xla_hlo_sharding(len(shape))))

  def test_devices_sharding_respects_init_mesh_shape(self):
    """PositionalSharding(mesh.devices) lowers like NamedSharding P('x', 'y')."""
    value_shape = (8, 4)

    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))

    devices_sharding = jax.sharding.PositionalSharding(mesh.devices)

    op1 = mps._to_xla_hlo_sharding(len(value_shape))
    op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape))

    self.assertEqual(mps.shard_shape(value_shape),
                     devices_sharding.shard_shape(value_shape))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

  def test_pmap_sharding_repr(self):
    """str/repr of a PmapSharding do not raise."""
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')
    out = jax.pmap(lambda x: x)(jnp.arange(2.))
    str(out.sharding)  # doesn't crash
    repr(out.sharding)  # doesn't crash

  def test_positional_sharding_repr(self):
    """str/repr of a reshaped PositionalSharding do not raise."""
    if jax.device_count() < 2:
      self.skipTest('Test needs >= 2 devices.')
    s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1)
    repr(s)  # doesn't crash
    str(s)  # doesn't crash

  def test_pspec_tuple(self):
    """PartitionSpec compares/indexes like a tuple; None and () hash alike."""
    pspec = P('x', 'y', 'z')
    self.assertEqual(pspec, ('x', 'y', 'z'))
    self.assertEqual(pspec.index('z'), 2)
    self.assertEqual(hash(P(None, 'x', 'y', 'z')), hash(P((), 'x', 'y', 'z')))

  @parameterized.named_parameters(
      ('sharded_dim_0', (4, 2), 0),
      ('sharded_dim_1_0', (4, 2), 1),
      ('sharded_dim_2', (4, 2, 4), 2),
      ('sharded_dim_1_1', (2, 4), 1)
  )
  def test_default_pmap_sharding(self, shape, sharded_dim):
    """PmapSharding.default matches the in_sharding pmap derives for in_axes."""
    if jax.device_count() < 4:
      self.skipTest('Test needs >= 4 devices.')
    ps = jax.sharding.PmapSharding.default(shape, sharded_dim)

    inp = jnp.arange(math.prod(shape)).reshape(shape)
    compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
    pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings

    self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment)
    self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec)

  def test_default_pmap_sharding_with_devices(self):
    """PmapSharding.default honors an explicit device order."""
    if jax.device_count() < 4:
      self.skipTest('Test needs >= 4 devices.')

    devs = jax.devices()
    new_order = (devs[0], devs[3], devs[2], devs[1])
    ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order)
    self.assertEqual(ps._device_assignment, new_order)

  def test_default_pmap_sharding_replicated(self):
    """PmapSharding.default(sharded_dim=None) equals pmap's out_axes=None."""
    x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32)
    x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(x)
    ps = jax.sharding.PmapSharding.default(
        shape=(8,), sharded_dim=None,
        devices=jax.local_devices())
    self.assertEqual(x.sharding, ps)

  def test_mesh_repr(self):
    """Mesh repr mentions its device ids and axis names."""
    mesh = jtu.create_mesh((1, 1), ('x', 'y'))
    mesh_repr = repr(mesh)
    self.assertIn('device_ids', mesh_repr)
    self.assertIn('axis_names', mesh_repr)

  def test_are_shardings_equivalent(self):
    """is_equivalent_to compares shardings across Sharding types."""
    # Note the trailing comma: ('x') would be a plain string, not a tuple.
    mesh = jtu.create_mesh((1,), ('x',))
    mesh2 = jtu.create_mesh((2, 1), ('x', 'y'))

    s1 = jax.sharding.NamedSharding(mesh, P('x'))
    s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    self.assertTrue(s1.is_equivalent_to(s2, 2))

    s3 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
    s4 = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())).sharding
    self.assertTrue(s3.is_equivalent_to(s4, 2))

    self.assertFalse(s1.is_equivalent_to(s3, 2))
    self.assertFalse(s2.is_equivalent_to(s3, 2))

    s5 = jax.sharding.NamedSharding(mesh2, P('x', 'y'))

    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.REPLICATED
    s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1)

    s7 = jax.sharding.GSPMDSharding(jax.devices(), op1)

    # The OpSharding is replicated but the Shardings themselves are on
    # different devices.
    self.assertFalse(s6.is_equivalent_to(s7, 2))

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_devices = [0, 1]
    op2.tile_assignment_dimensions = [2, 1]
    s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)

    self.assertTrue(s1.is_equivalent_to(s6, 2))
    self.assertTrue(s5.is_equivalent_to(s8, 2))
    self.assertFalse(s5.is_equivalent_to(s2, 2))

    s9 = jax.sharding.NamedSharding(mesh2, P('y'))

    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.OTHER
    op3.tile_assignment_devices = [0, 1]
    op3.tile_assignment_dimensions = [1, 1, 2]
    op3.replicate_on_last_tile_dim = True
    s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)

    self.assertTrue(s9.is_equivalent_to(s10, 2))

  def test_devices_indices_map_good_error_message(self):
    """devices_indices_map names the axis that cannot be evenly partitioned."""
    shape = (1, 2)
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    with self.assertRaisesRegex(
        ValueError,
        "Sharding.*implies that array axis 0 is partitioned 2 times, but the "
        "dimension size is 1"):
      s.devices_indices_map(shape)

  def test_scalar_input_wrong_pspec(self):
    """A non-empty PartitionSpec is rejected for a scalar shape."""
    mesh = jtu.create_mesh((1, ), ('x',))
    shape = ()
    s = jax.sharding.NamedSharding(mesh, P('x'))
    # Parentheses are escaped: an unescaped `P()` is "P" plus an empty regex
    # group and would not actually check for the literal "P()".
    with self.assertRaisesRegex(
        ValueError,
        r"For scalars the PartitionSpec should be P\(\)"):
      s.check_compatible_aval(shape)

  def test_mesh_caching_during_construction(self):
    """Constructing an identical Mesh twice returns the same object."""
    if jax.device_count() < 2:
      self.skipTest("Requires >=2 devices")
    mesh1 = jax.sharding.Mesh(jax.devices(), 'x')
    mesh2 = jax.sharding.Mesh(jax.devices(), 'x')

    self.assertIs(mesh1, mesh2)

  def test_mesh_str(self):
    """str(Mesh) lists axis sizes and axis types."""
    mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
    self.assertEqual(
        str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto))"
    )

  def test_make_array_from_callback_error(self):
    """make_array_from_callback works eagerly but raises under jit."""
    mesh_shape = (2, 3)
    global_shape = tuple(np.square(mesh_shape))
    mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
    pspec = P('x', 'y')
    sharding = jax.sharding.NamedSharding(mesh, pspec)
    n = math.prod(global_shape)
    global_x = jnp.arange(n).astype('uint32').reshape(global_shape)

    def f(arr):
      return array.make_array_from_callback(arr.shape, sharding, lambda i: arr[i])

    out = f(global_x)
    self.assertEqual(out.shape, global_shape)

    msg = "jax.make_array_from_callback cannot be called within a traced context"
    with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, msg):
      jax.jit(f)(global_x)

  def test_make_array_from_single_device_arrays_error(self):
    """Tracers are rejected as inputs to make_array_from_single_device_arrays."""
    x = jnp.arange(10)
    sharding = x.sharding

    def f(x):
      return jax.make_array_from_single_device_arrays(x.shape, sharding, [x])

    msg = "jax.make_array_from_single_device_arrays requires a list of concrete arrays"
    with self.assertRaisesRegex(ValueError, msg):
      jax.jit(f)(x)

  def test_make_array_from_single_device_arrays_nonlist_error(self):
    """A non-list `arrays` argument raises TypeError."""
    x = jnp.arange(10)
    sharding = x.sharding

    def f(x):
      return jax.make_array_from_single_device_arrays(x.shape, sharding, x)

    msg = "jax.make_array_from_single_device_arrays `arrays` argument"
    with self.assertRaisesRegex(TypeError, msg):
      jax.jit(f)(x)

  def test_make_array_from_single_device_arrays_bad_inputs(self):
    """Input arrays with more than one shard are rejected."""
    x = jnp.arange(10)
    mesh = jtu.create_mesh((2,), ('x',))
    s = jax.sharding.NamedSharding(mesh, P('x'))
    x = jax.device_put(x, s)

    msg = ("When making an array from single-device arrays the input arrays "
           "must have one shard each. An argument array had 2 shard\\(s\\).")
    with self.assertRaisesRegex(ValueError, msg):
      jax.make_array_from_single_device_arrays(x.shape, s, [x, x])

  def test_gspmd_sharding_hash_eq(self):
    """Tiling over size-1 axes equals and hashes like replication."""
    mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
    ns = NamedSharding(mesh, P('x', 'y', 'z'))

    x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3))
    x2 = GSPMDSharding.get_replicated(mesh._flat_devices_tuple)

    self.assertEqual(x1, x2)
    self.assertEqual(hash(x1), hash(x2))

  def test_device_attr(self):
    """Array.device is a Device for one device, else the array's sharding."""
    # For single-device arrays, x.device returns the device
    x = jnp.ones((2, 10))
    self.assertEqual(x.device, list(x.devices())[0])

    # For sharded arrays, x.device returns the sharding
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))
    x = jax.device_put(x, sharding)
    self.assertEqual(x.device, sharding)

  def test_to_device(self):
    """Array.to_device accepts either a Device or a Sharding."""
    device = jax.devices()[-1]
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))

    x = jnp.ones((2, 10))

    x_device = x.to_device(device)
    x_sharding = x.to_device(sharding)

    self.assertEqual(x_device.device, device)
    self.assertEqual(x_sharding.device, sharding)

  def test_mesh_with_axis_name_none(self):
    """None is not accepted as a mesh axis name."""
    with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
      jax.sharding.Mesh(jax.devices(), (None, 'x'))

  def test_mesh_axis_types_mismatch(self):
    """A single AxisType for a multi-axis (abstract) mesh is rejected."""
    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jtu.create_mesh((2, 1), ('x', 'y'),
                      axis_types=jax.sharding.AxisTypes.Auto)

    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
                                axis_types=jax.sharding.AxisTypes.Auto)

  def test_make_mesh_axis_types(self):
    """axis_types are grouped per type and take part in Mesh equality/hash."""
    Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual

    mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
    mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
    self.assertEqual(mesh1, mesh2)

    mesh = jax.make_mesh((1, 1), ('x', 'y'))
    self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Auto: ('x', 'y')})

    mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
                         axis_types=(Explicit, Auto, Manual))
    self.assertDictEqual(
        mesh._axis_types_dict, {AxisTypes.Auto: ('y',), AxisTypes.Explicit: ('x',),
                                AxisTypes.Manual: ('z',)})

    mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
                         axis_types=(Explicit, Explicit, Manual))
    self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y'),
                                                 AxisTypes.Manual: ('z',)})

    mesh = jax.make_mesh((1, 1), ('x', 'y'), axis_types=(Explicit, Explicit))
    self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y')})

    mesh = jax.make_mesh((1,), 'model', axis_types=Manual)
    self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Manual: ('model',)})

    with self.assertRaisesRegex(
        ValueError,
        'Number of axis names should match the number of axis_types'):
      jax.make_mesh((1, 1), ('data', 'model'), axis_types=Explicit)

    mesh1 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
                          axis_types=(Explicit, Auto, Auto, Explicit, Explicit))
    mesh2 = jax.make_mesh((1, 1, 1, 1, 1), ('a', 'b', 'c', 'd', 'e'),
                          axis_types=(Explicit, Auto, Auto, Explicit, Auto))
    self.assertNotEqual(mesh1, mesh2)
    self.assertNotEqual(hash(mesh1), hash(mesh2))
|
|
|
|
|
|
@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyShardingTest(jtu.JaxTestCase):
  """Checks conversion of NamedSharding into Shardy (sdy) array shardings."""

  def _assert_builds_to(self, sdy_sharding, expected_attr):
    """Builds `sdy_sharding` in a fresh MLIR context and compares its text."""
    with ir.Context() as ctx:
      dialects.sdy.register_dialect(ctx)
      self.assertEqual(str(sdy_sharding.build()), expected_attr)

  def test_long_axis_names(self):
    """Multi-character axis names survive conversion and MLIR printing."""
    mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model'))
    spec = P(('sequence', 'data'), 'model')
    actual = jax.sharding.NamedSharding(mesh, spec)._to_sdy_sharding(3)
    expected_dims = [
        SdyDimSharding(('sequence', 'data'), True),
        SdyDimSharding(('model',), True),
        SdyDimSharding([], True),
    ]
    self.assertEqual(actual, SdyArraySharding(mesh.shape_tuple, expected_dims))
    self._assert_builds_to(
        actual,
        '#sdy.sharding<mesh<["sequence"=2, "data"=2, "model"=2]>,'
        ' [{"sequence", "data"}, {"model"}, {}]>')

  def test_unconstrained(self):
    """P.UNCONSTRAINED becomes an open (`?`) dimension sharding."""
    mesh = jtu.create_mesh((8,), ('x',))
    spec = P(None, P.UNCONSTRAINED, 'x')
    actual = jax.sharding.NamedSharding(mesh, spec)._to_sdy_sharding(3)
    expected_dims = [
        SdyDimSharding([], True),
        SdyDimSharding([], False),
        SdyDimSharding(('x',), True),
    ]
    self.assertEqual(actual, SdyArraySharding(mesh.shape_tuple, expected_dims))
    self._assert_builds_to(
        actual, '#sdy.sharding<mesh<["x"=8]>, [{}, {?}, {"x"}]>')
|
|
|
|
|
|
class RngShardingTest(jtu.JaxTestCase):
  """Tests that the PRNGs are automatically sharded as expected."""

  def _assert_pure_map(self, f, x):
    """Asserts jitted `f` runs on sharded `x` as a communication-free map.

    `f` is lowered once with partitionable threefry. The unoptimized HLO must
    mention the full global shape of `x`. The compiled HLO must not mention it
    and must contain no all-reduce or collective-permute. Finally `f(x)` must
    equal `f` evaluated on a single-device copy of `x`.
    """
    # Global shape as printed inside HLO types, e.g. "9" or "9,4".
    global_shape_fmt = ','.join(map(str, x.shape))
    with jax.threefry_partitionable(True):
      # Lower once and reuse it for both the unoptimized and compiled text.
      lowered = f.lower(x)
      unopt_txt = lowered.as_text(dialect='hlo')
      opt_txt = lowered.compile().as_text()
      self.assertIn(f'[{global_shape_fmt}]', unopt_txt)
      self.assertNotIn(f'[{global_shape_fmt}]', opt_txt)
      self.assertNotIn('all-reduce', opt_txt)
      self.assertNotIn('collective-permute', opt_txt)

    # Check against a single-device reference.
    y = f(x)
    y_ref1 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref1)

  @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_1d(self, num_devices):
    """threefry_random_bits + x on a 1D-sharded input needs no communication."""
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True)
    s = jax.sharding.NamedSharding(mesh, P('x'))

    n = num_devices ** 2
    # Host data, consistent with the 2D test below.
    global_x = np.arange(n).astype('uint32')
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    self._assert_pure_map(f, x)

  @parameterized.named_parameters(
      {"testcase_name": f"_{mesh_shape}_{pspec}",
       "mesh_shape": mesh_shape, "pspec": pspec}
      for mesh_shape in [(3, 2), (4, 2), (2, 3)]
      for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
  @jtu.skip_on_devices("gpu")
  def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
    """threefry_random_bits + x on a 2D-sharded input needs no communication."""
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    global_shape = tuple(np.square(mesh_shape))

    mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True)
    s = jax.sharding.NamedSharding(mesh, pspec)

    n = math.prod(global_shape)
    global_x = np.arange(n).astype('uint32').reshape(global_shape)
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    self._assert_pure_map(f, x)

  def test_empty_mesh_creation(self):
    """An empty Mesh and its AbstractMesh report empty and size 0."""
    mesh = jax.sharding.Mesh(devices=np.empty((), dtype=object), axis_names=[])
    self.assertTrue(mesh.empty)
    self.assertEqual(mesh.size, 0)

    abstract_mesh = mesh.abstract_mesh
    self.assertTrue(abstract_mesh.empty)
    self.assertEqual(abstract_mesh.size, 0)

    abstract_mesh2 = jax.sharding.AbstractMesh((), ())
    self.assertTrue(abstract_mesh2.empty)
    self.assertEqual(abstract_mesh2.size, 0)
|
|
|
|
|
|
if __name__ == '__main__':
  # Run via absltest with JAX's custom test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)
|