mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Rename jtu.create_global_mesh
to jtu.create_mesh
and use jax.make_mesh
inside jtu.create_mesh
to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
This commit is contained in:
parent
5c0ee1a3e9
commit
e1b497078e
@ -1375,15 +1375,16 @@ def with_and_without_mesh(f):
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))(with_mesh_from_kwargs(f))
|
||||
|
||||
def create_global_mesh(mesh_shape, axis_names):
|
||||
def create_mesh(mesh_shape, axis_names, iota_order=False):
|
||||
size = math.prod(mesh_shape)
|
||||
if len(jax.devices()) < size:
|
||||
raise unittest.SkipTest(f"Test requires {size} global devices.")
|
||||
devices = sorted(jax.devices(), key=lambda d: d.id)
|
||||
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
|
||||
global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
|
||||
return global_mesh
|
||||
|
||||
if iota_order:
|
||||
devices = sorted(jax.devices(), key=lambda d: d.id)
|
||||
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
|
||||
return jax.sharding.Mesh(mesh_devices, axis_names)
|
||||
else:
|
||||
return jax.make_mesh(mesh_shape, axis_names)
|
||||
|
||||
class _cached_property:
|
||||
null = object()
|
||||
|
@ -52,7 +52,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices('cpu')
|
||||
def test_memory_consumption(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((2, 4), ('x', 'y'))
|
||||
inp_shape = (2_048, 4_096)
|
||||
pspec = P('x', 'y')
|
||||
num = math.prod(inp_shape)
|
||||
@ -97,7 +97,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
tm.stop()
|
||||
|
||||
def test_memory_consumption_for_save(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 1), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((1, 1), ('x', 'y'))
|
||||
inp_shape = (16 * 1024, 16 * 1024)
|
||||
pspec = P('x', 'y')
|
||||
num = math.prod(inp_shape)
|
||||
@ -132,7 +132,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
tm.stop()
|
||||
|
||||
def test_checkpointing_with_path_variant(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
pspec = P('x', 'y')
|
||||
num = math.prod(inp_shape)
|
||||
@ -164,7 +164,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertEqual(m1.dtype, np.int32)
|
||||
|
||||
def test_checkpointing_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
pspec = P('x', 'y')
|
||||
num = math.prod(inp_shape)
|
||||
@ -188,7 +188,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
# Third Array
|
||||
def cb3(_):
|
||||
return np.array([], dtype=np.float32)
|
||||
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
|
||||
global_mesh1d = jtu.create_mesh((8,), ('x',))
|
||||
a3 = array.make_array_from_callback(
|
||||
(0,), NamedSharding(global_mesh1d, P(None)), cb3)
|
||||
ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path)
|
||||
@ -232,7 +232,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_ocdbt_transaction(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
pspec = P('x', 'y')
|
||||
num = math.prod(inp_shape)
|
||||
@ -262,7 +262,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
def cb3(_):
|
||||
return np.array([], dtype=np.float32)
|
||||
|
||||
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
|
||||
global_mesh1d = jtu.create_mesh((8,), ('x',))
|
||||
a3 = array.make_array_from_callback(
|
||||
(0,), NamedSharding(global_mesh1d, P(None)), cb3
|
||||
)
|
||||
@ -327,7 +327,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
|
||||
def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True)
|
||||
global_input_shape = (8, 2)
|
||||
num = math.prod(global_input_shape)
|
||||
|
||||
@ -349,7 +349,8 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
|
||||
manager.wait_until_finished()
|
||||
|
||||
ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
|
||||
ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True),
|
||||
P('x', 'y'))
|
||||
|
||||
m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
|
||||
[np.float32])
|
||||
@ -375,7 +376,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.product(input_dtype=[jnp.int4, jnp.int8])
|
||||
def test_checkpointing_with_int4(self, input_dtype):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True)
|
||||
global_input_shape = (8, 2)
|
||||
num = math.prod(global_input_shape)
|
||||
|
||||
@ -397,7 +398,8 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
|
||||
manager.wait_until_finished()
|
||||
|
||||
ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
|
||||
ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True),
|
||||
P('x', 'y'))
|
||||
|
||||
target_dtype = jnp.dtype('int4')
|
||||
m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
|
||||
@ -424,7 +426,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(l.data, global_input_data.astype(target_dtype))
|
||||
|
||||
def test_checkpointing_scalar_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_mesh = jtu.create_mesh((2,), ('x'))
|
||||
global_input_shape = ()
|
||||
data = np.array(4)
|
||||
s = NamedSharding(global_mesh, P(None))
|
||||
@ -441,7 +443,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
|
||||
manager.wait_until_finished()
|
||||
|
||||
ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None))
|
||||
ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None))
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[ds],
|
||||
@ -454,7 +456,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_deserialize_tensorstore_array_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_mesh = jtu.create_mesh((2,), ('x'))
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
m1, = serialization.run_deserialization(
|
||||
@ -550,7 +552,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
if not jtu.test_device_matches(['tpu']):
|
||||
self.skipTest('Layouts are only supported on TPUs')
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
np_inp = np.arange(32).reshape(8, 4)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
@ -81,7 +81,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
("mesh_fully_replicated", P()),
|
||||
)
|
||||
def test_jax_array_value(self, mesh_axes):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, global_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes))
|
||||
@ -121,7 +121,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids, expected_is_fully_replicated):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True)
|
||||
global_input_shape = (8, 2)
|
||||
s = jax.sharding.NamedSharding(global_mesh, mesh_axes)
|
||||
arr, global_input_data = create_array(global_input_shape, s)
|
||||
@ -148,7 +148,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
def test_addressable_data(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
shape = (8, 2)
|
||||
s = jax.sharding.NamedSharding(global_mesh, P(None))
|
||||
arr, inp_data = create_array(shape, s)
|
||||
@ -156,7 +156,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(inp_data, arr.addressable_data(i))
|
||||
|
||||
def test_array_delete(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, _ = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -174,7 +174,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
_ = x + 1
|
||||
|
||||
def test_multi_device_array_usage_after_delete(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
shape = (8, 2)
|
||||
arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32),
|
||||
jax.sharding.NamedSharding(global_mesh, P('x')))
|
||||
@ -205,14 +205,14 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsNone(arr._arrays)
|
||||
|
||||
def test_array_device_get(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
self.assertArraysEqual(jax.device_get(arr), input_data)
|
||||
|
||||
def test_repr(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, _ = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -254,7 +254,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
|
||||
|
||||
def test_array_sharded_astype(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -272,7 +272,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(arr_float32, arr.astype(np.float32))
|
||||
|
||||
def test_array_delete_idempotent(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x')))
|
||||
|
||||
arr.delete()
|
||||
@ -282,7 +282,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertTrue(arr.is_deleted())
|
||||
|
||||
def test_sharded_add(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
a, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -296,7 +296,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(i.data, expected[i.index])
|
||||
|
||||
def test_sharded_zeros_like(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
a, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -318,7 +318,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest('Requires more than 4 devices')
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
||||
devices = jax.local_devices()[:2] # Taking up to 2 devices
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -342,7 +342,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest('Requires more than 4 devices')
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
||||
# sharding device ids = {0, 1}
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -378,7 +378,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
if xc._version <= 274:
|
||||
self.skipTest('Test requires jaxlib version 275')
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
||||
# Sharding device ids = {0, 1}
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -401,7 +401,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
|
||||
shape = (8, 4)
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
inp_data = np.arange(5)
|
||||
|
||||
@ -415,7 +415,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_mismatch_dtype(self):
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
|
||||
indices = s.devices_indices_map(shape)
|
||||
@ -452,7 +452,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(i, j)
|
||||
|
||||
def test_array_iter_mesh_pspec_sharding_multi_device(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -462,7 +462,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(i, j)
|
||||
|
||||
def test_array_iter_replicated_multi_device(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
|
||||
@ -477,7 +477,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
i.sharding._to_xla_hlo_sharding(i.ndim)))
|
||||
|
||||
def test_array_getitem_mesh_pspec_sharding_multi_device(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -496,7 +496,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape)
|
||||
self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding)
|
||||
|
||||
global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
input_shape = (4, 4, 2)
|
||||
arr, np_inp = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z')))
|
||||
@ -523,7 +523,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
_check(arr[1], np_inp[1], (2, 1))
|
||||
|
||||
def test_array_getitem_replicated_multi_device(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, input_data = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))
|
||||
@ -575,7 +575,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertTrue(s.data._committed)
|
||||
|
||||
def test_array_jnp_array_copy_multi_device(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, _ = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -592,7 +592,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
c.data.unsafe_buffer_pointer())
|
||||
|
||||
def test_array_addressable_shards(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
arr, _ = create_array(
|
||||
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
@ -620,7 +620,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
check_tracer_hash(x)
|
||||
|
||||
def test_shape_dtype_struct_sharding_jit(self):
|
||||
mesh = jtu.create_global_mesh((8,), ('x'))
|
||||
mesh = jtu.create_mesh((8,), ('x'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
x_dummy = jax.ShapeDtypeStruct(
|
||||
@ -647,7 +647,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
s._to_xla_hlo_sharding(x_dummy.ndim)))
|
||||
|
||||
def test_shape_dtype_struct_sharding_pjit(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
|
||||
def f(x):
|
||||
@ -677,7 +677,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.skipTest("Manual defragment not exposed via PJRT C API")
|
||||
|
||||
# Create a few arrays
|
||||
global_mesh = jtu.create_global_mesh((jax.local_device_count(),), ('x',))
|
||||
global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',))
|
||||
shape = (8, 2)
|
||||
mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',))
|
||||
arr1, data = create_array(shape, mpsharding)
|
||||
@ -700,7 +700,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
# OOM, and exposing allocator stats in Python.
|
||||
|
||||
def test_on_device_size_in_bytes(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
a, _ = create_array(
|
||||
(8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
|
||||
shard_size = a.addressable_shards[0].data.on_device_size_in_bytes()
|
||||
@ -756,7 +756,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(x_bytes, y_bytes)
|
||||
|
||||
def test_array_copy_to_host_async(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = pjit(lambda: jnp.arange(8.),
|
||||
out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))()
|
||||
self.assertLen(x.sharding.device_set, 4)
|
||||
@ -765,7 +765,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_array_fully_replicated_shard(self):
|
||||
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
arr, inp_data = create_array(
|
||||
inp_shape, jax.sharding.NamedSharding(global_mesh, P()))
|
||||
@ -776,7 +776,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(arr.addressable_data(0), inp_data)
|
||||
|
||||
def test_shard_array_to_fully_replicated(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(global_mesh, P())
|
||||
arr = jnp.arange(16)
|
||||
self.assertFalse(arr._committed)
|
||||
@ -786,7 +786,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, arr * 2)
|
||||
|
||||
def test_fully_replicated_donated_array_is_deleted(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(global_mesh, P())
|
||||
arr = jnp.arange(16)
|
||||
arr_copy = arr.copy()
|
||||
@ -804,7 +804,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(shard.data.dtype, dtype)
|
||||
|
||||
def test_make_array_from_callback_global_array(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P())
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, sharding)
|
||||
@ -823,7 +823,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_make_array_from_process_data_single_host_data_sharding(self):
|
||||
data = np.ones((1, 512))
|
||||
mesh = jtu.create_global_mesh((1, 1), ('x', 'unused'))
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'unused'))
|
||||
sharding_spec = jax.sharding.NamedSharding(
|
||||
mesh, jax.sharding.PartitionSpec('x')
|
||||
)
|
||||
@ -838,7 +838,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_mesh_pspec_sharding_interface(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
pspec = P('y', 'x')
|
||||
global_shape = (8, 4)
|
||||
mp_sharding = jax.sharding.NamedSharding(mesh, pspec)
|
||||
@ -855,7 +855,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
|
||||
def test_util_clear_cache(self):
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
mesh = jtu.create_mesh((1,), ('x',))
|
||||
s = NamedSharding(mesh, P())
|
||||
s.devices_indices_map((8,))
|
||||
jax.clear_caches()
|
||||
@ -874,7 +874,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_op_sharding_indices(self, pspec):
|
||||
shape = (8, 4)
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
ops = jax.sharding.GSPMDSharding(
|
||||
list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape)))
|
||||
@ -892,12 +892,12 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_shard_shape(self, pspec, expected_shard_shape):
|
||||
shape = (8, 4)
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
self.assertEqual(mps.shard_shape(shape), expected_shard_shape)
|
||||
|
||||
def test_uneven_shard_error(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -930,7 +930,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_is_compatible_error(self):
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
|
||||
mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
|
||||
mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))
|
||||
new_mps = jax.sharding.NamedSharding._from_parsed_pspec(
|
||||
mps.mesh, mps._parsed_pspec)
|
||||
@ -982,7 +982,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self, pspec, shape, axes, transpose):
|
||||
value_shape = (8, 4)
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
devices = jax.local_devices()[:8] # Taking up to 8 devices
|
||||
|
||||
@ -1038,7 +1038,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
|
||||
ndim = len(mesh_shape)
|
||||
mesh = jtu.create_global_mesh(
|
||||
mesh = jtu.create_mesh(
|
||||
mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
original_op_sharding = mps._to_xla_hlo_sharding(ndim)
|
||||
@ -1071,7 +1071,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
axis_names = ('x', 'y', 'z')
|
||||
else:
|
||||
axis_names = ('x',)
|
||||
mesh = jtu.create_global_mesh(mesh_shape, axis_names)
|
||||
mesh = jtu.create_mesh(mesh_shape, axis_names)
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
shape = (8, 2, 4)
|
||||
mps_op_sharding = mps._to_xla_hlo_sharding(len(shape))
|
||||
@ -1086,7 +1086,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
def test_devices_sharding_respects_init_mesh_shape(self):
|
||||
value_shape = (8, 4)
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
mps = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
|
||||
devices_sharding = jax.sharding.PositionalSharding(mesh.devices)
|
||||
@ -1140,14 +1140,14 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(ps._device_assignment, new_order)
|
||||
|
||||
def test_mesh_repr(self):
|
||||
mesh = jtu.create_global_mesh((1, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
|
||||
mesh_repr = repr(mesh)
|
||||
self.assertIn('device_ids', mesh_repr)
|
||||
self.assertIn('axis_names', mesh_repr)
|
||||
|
||||
def test_are_shardings_equivalent(self):
|
||||
mesh = jtu.create_global_mesh((1,), ('x'))
|
||||
mesh2 = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1,), ('x'))
|
||||
mesh2 = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
|
||||
s1 = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0])
|
||||
@ -1196,7 +1196,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_devices_indices_map_good_error_message(self):
|
||||
shape = (1, 2)
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -1205,7 +1205,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
s.devices_indices_map(shape)
|
||||
|
||||
def test_scalar_input_wrong_pspec(self):
|
||||
mesh = jtu.create_global_mesh((1, ), ('x'))
|
||||
mesh = jtu.create_mesh((1, ), ('x'))
|
||||
shape = ()
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
with self.assertRaisesRegex(
|
||||
@ -1222,13 +1222,13 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertIs(mesh1, mesh2)
|
||||
|
||||
def test_mesh_str(self):
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
self.assertEqual(str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2)")
|
||||
|
||||
def test_make_array_from_callback_error(self):
|
||||
mesh_shape = (2, 3)
|
||||
global_shape = tuple(np.square(mesh_shape))
|
||||
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
|
||||
mesh = jtu.create_mesh(mesh_shape, ('x', 'y'))
|
||||
pspec = P('x', 'y')
|
||||
sharding = jax.sharding.NamedSharding(mesh, pspec)
|
||||
n = math.prod(global_shape)
|
||||
@ -1257,7 +1257,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_make_array_from_single_device_arrays_bad_inputs(self):
|
||||
x = jnp.arange(10)
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
x = jax.device_put(x, s)
|
||||
|
||||
@ -1268,7 +1268,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
def test_gspmd_sharding_hash_eq(self):
|
||||
mesh = jtu.create_global_mesh((1, 1, 1), ('x', 'y', 'z'))
|
||||
mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
|
||||
ns = NamedSharding(mesh, P('x', 'y', 'z'))
|
||||
|
||||
x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3))
|
||||
@ -1283,14 +1283,14 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(x.device, list(x.devices())[0])
|
||||
|
||||
# For sharded arrays, x.device returns the sharding
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
x = jax.device_put(x, sharding)
|
||||
self.assertEqual(x.device, sharding)
|
||||
|
||||
def test_to_device(self):
|
||||
device = jax.devices()[-1]
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
x = jnp.ones((2, 10))
|
||||
@ -1306,7 +1306,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
class ShardyShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_long_axis_names(self):
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('sequence', 'data', 'model'))
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model'))
|
||||
s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model'))
|
||||
sdy_sharding = s._to_sdy_sharding(3)
|
||||
self.assertEqual(
|
||||
@ -1323,7 +1323,7 @@ class ShardyShardingTest(jtu.JaxTestCase):
|
||||
'#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>')
|
||||
|
||||
def test_unconstrained(self):
|
||||
mesh = jtu.create_global_mesh((8,), ('x',))
|
||||
mesh = jtu.create_mesh((8,), ('x',))
|
||||
s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x'))
|
||||
sdy_sharding = s._to_sdy_sharding(3)
|
||||
self.assertEqual(
|
||||
@ -1351,7 +1351,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
32, x.shape)
|
||||
return bits + x
|
||||
|
||||
mesh = jtu.create_global_mesh((num_devices,), ('x',))
|
||||
mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True)
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
n = num_devices ** 2
|
||||
@ -1387,7 +1387,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
|
||||
global_shape = tuple(np.square(mesh_shape))
|
||||
|
||||
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
|
||||
mesh = jtu.create_mesh(mesh_shape, ('x', 'y'))
|
||||
s = jax.sharding.NamedSharding(mesh, pspec)
|
||||
|
||||
n = math.prod(global_shape)
|
||||
|
@ -1098,7 +1098,7 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
return jnp.square(x)
|
||||
|
||||
f = jax.jit(f_)
|
||||
mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
mesh = jtu.create_mesh((2,), ('x'))
|
||||
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s)
|
||||
|
||||
@ -1114,7 +1114,7 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
return jnp.square(x)
|
||||
|
||||
f = pjit.pjit(f_)
|
||||
mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
mesh = jtu.create_mesh((2,), ('x'))
|
||||
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s)
|
||||
|
||||
|
@ -49,7 +49,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_auto_layout(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape1 = (128, 128)
|
||||
shape2 = (128, 128)
|
||||
s1 = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -116,7 +116,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_default_layout(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (4, 4, 2)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -157,7 +157,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_in_layouts_out_layouts(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (8, 8)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -183,7 +183,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_sharding_and_layouts(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (4, 8)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -226,7 +226,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out2, out6)
|
||||
|
||||
def test_no_error_dced_args(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
shape = (8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -247,7 +247,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_aot_layout_mismatch(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (256, 4, 2)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
@ -283,7 +283,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
out_cpu, out_cpu).compile() # doesn't crash
|
||||
|
||||
def test_device_put_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (8, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -326,7 +326,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
Layout(compiled.output_layouts[0], None)
|
||||
|
||||
def test_layout_on_sds(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -345,7 +345,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO))
|
||||
|
||||
def test_make_array_from_callback(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
|
||||
@ -370,7 +370,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
np_inp.shape, Layout(None, None), lambda idx: np_inp[idx])
|
||||
|
||||
def test_wsc_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -393,7 +393,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
|
||||
def test_wsc_bfloat16_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
|
||||
@ -430,7 +430,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
def test_concrete_layout_jit(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -474,7 +474,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
def test_concrete_layout_in_shardings(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -528,7 +528,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out4, np_inp + 1)
|
||||
|
||||
def test_layout_donation(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -544,7 +544,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertTrue(arr.is_deleted())
|
||||
|
||||
def test_layout_donation_auto(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (128, 16)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -559,7 +559,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertTrue(arr.is_deleted())
|
||||
|
||||
def test_layout_donation_matching_in_and_out(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (128, 16)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -577,7 +577,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices('cpu', 'gpu')
|
||||
def test_layout_donation_mismatching_in_and_out_fails(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16*2, 32016*2)
|
||||
np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
|
||||
|
@ -47,7 +47,7 @@ def get_memory_kinds_from_executable(f, args):
|
||||
|
||||
|
||||
def _create_inputs(shape, pspec, mem_kind=None):
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, pspec, memory_kind=mem_kind)
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -71,7 +71,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_canonicalize_memory_kind(self, name):
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((1,), "x")
|
||||
mesh = jtu.create_mesh((1,), "x")
|
||||
ns = NamedSharding(mesh, P("x"))
|
||||
self.assertEqual(ns.memory_kind, self._default_memory_kind)
|
||||
elif name == "positional_sharding":
|
||||
@ -96,7 +96,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Could not find memory addressable by device.*"
|
||||
):
|
||||
mesh = jtu.create_global_mesh((1,), ("x",))
|
||||
mesh = jtu.create_mesh((1,), ("x",))
|
||||
NamedSharding(mesh, P("x"), memory_kind="hbm")
|
||||
elif name == "positional_sharding":
|
||||
with self.assertRaisesRegex(
|
||||
@ -128,7 +128,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
self.skipTest("TPU memory kind test.")
|
||||
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((1,), ("x",))
|
||||
mesh = jtu.create_mesh((1,), ("x",))
|
||||
NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
|
||||
elif name == "positional_sharding":
|
||||
PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
|
||||
@ -146,7 +146,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_sharding_eq(self, name):
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((1,), ("x",))
|
||||
mesh = jtu.create_mesh((1,), ("x",))
|
||||
s1 = NamedSharding(mesh, P("x"))
|
||||
s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
|
||||
self.assertEqual(s1, s2)
|
||||
@ -164,7 +164,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(s1, s2)
|
||||
|
||||
def test_sharding_equivalent(self):
|
||||
mesh = jtu.create_global_mesh((1,), ("x",))
|
||||
mesh = jtu.create_mesh((1,), ("x",))
|
||||
ndim = 2
|
||||
ns1 = NamedSharding(mesh, P("x"))
|
||||
gs1 = GSPMDSharding(
|
||||
@ -215,7 +215,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_device_put_host_to_hbm(self, host_memory_kind: str):
|
||||
if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
|
||||
self.skipTest("unpinned_host does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
||||
@ -231,7 +231,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_device_put_hbm_to_host(self, host_memory_kind: str):
|
||||
if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
|
||||
self.skipTest("unpinned_host does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
|
||||
inp = jnp.arange(16).reshape(8, 2)
|
||||
|
||||
@ -314,7 +314,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(yashkatariya): Enable this once we can compute on host.
|
||||
# def test_device_put_resharding(self):
|
||||
# mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
# mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
# s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
|
||||
# s_hbm = s_host.with_memory_kind("device")
|
||||
# np_inp = np.arange(16).reshape(8, 2)
|
||||
@ -341,7 +341,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_device_put_numpy_array(self, host_memory_kind: str):
|
||||
if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
|
||||
self.skipTest("unpinned_host does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
|
||||
s_host = s_hbm.with_memory_kind(host_memory_kind)
|
||||
@ -462,7 +462,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_parameter_streaming_with_scalar_and_constant(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
scalar_inp = 1
|
||||
s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
|
||||
|
||||
@ -488,7 +488,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_parameter_and_output_streaming_with_array(self):
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
|
||||
inp_host = jax.device_put(np_inp, s_host)
|
||||
@ -540,7 +540,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_identity_jit_host_to_device_and_vice_versa(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
|
||||
s_dev = s_host.with_memory_kind('device')
|
||||
@ -560,7 +560,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_host.sharding, s_host)
|
||||
|
||||
def test_parameter_streaming_inside_scan(self):
|
||||
mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
|
||||
mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
|
||||
np_inp = np.arange(4096.0).reshape(16, 16, 16)
|
||||
s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
|
||||
arr_host = jax.device_put(np_inp, s_host)
|
||||
@ -582,7 +582,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
def test_output_streaming(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test is flaky on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((1, 1), ("x", "y"))
|
||||
mesh = jtu.create_mesh((1, 1), ("x", "y"))
|
||||
np_inp = np.arange(16.0).reshape(8, 2)
|
||||
s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
|
||||
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
|
||||
@ -619,7 +619,7 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.skipTest("This test does not work on GPU backend.")
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
|
||||
mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
|
||||
np_inp = np.arange(4096).reshape(16, 16, 16)
|
||||
s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device")
|
||||
arr_hbm = jax.device_put(np_inp, s_hbm)
|
||||
@ -680,7 +680,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(executable_kind, expected_kind)
|
||||
|
||||
def test_compute_no_inputs(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('data'))
|
||||
mesh = jtu.create_mesh((4,), ('data'))
|
||||
|
||||
tpu_sharding = NamedSharding(mesh, P('data'))
|
||||
cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host')
|
||||
@ -698,7 +698,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
def test_compute_no_inputs_host_replicated(self):
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
|
||||
self.skipTest("This test requires an xla_version >= 3.")
|
||||
mesh = jtu.create_global_mesh((4,), ('data'))
|
||||
mesh = jtu.create_mesh((4,), ('data'))
|
||||
|
||||
tpu_sharding = NamedSharding(mesh, P('data'))
|
||||
cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host')
|
||||
@ -843,7 +843,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertLen(out, 2)
|
||||
|
||||
def test_nested_no_op_compute(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -868,7 +868,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
def test_sharded_compute_on_host(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -921,7 +921,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
def test_host_offload_in_custom_vjp_sharded(self):
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
||||
@jax.custom_vjp
|
||||
@ -1005,7 +1005,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
def test_pure_host_data_and_compute(self):
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr_host = jax.device_put(np_inp, s)
|
||||
@ -1032,7 +1032,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertArraysAllClose(out, jnp.sin(inp * 2))
|
||||
|
||||
def test_compute_per_annotation(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
s = NamedSharding(mesh, P("x", "y"))
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -1223,7 +1223,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)
|
||||
|
||||
def test_sharding_devices_indices_map_cache_hit(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
shape = (8, 2)
|
||||
s1 = NamedSharding(mesh, P("x", "y"))
|
||||
s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device")
|
||||
@ -1238,7 +1238,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
def test_no_donation_across_memory_kinds(self):
|
||||
if xb.using_pjrt_c_api():
|
||||
raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API")
|
||||
mesh = jtu.create_global_mesh((2, 1), ("x", "y"))
|
||||
mesh = jtu.create_mesh((2, 1), ("x", "y"))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s_hbm = NamedSharding(mesh, P("x"))
|
||||
s_host = s_hbm.with_memory_kind("pinned_host")
|
||||
@ -1257,7 +1257,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertNotDeleted(inp)
|
||||
|
||||
def test_single_mem_kind_donation_default_mem_kind(self):
|
||||
mesh = jtu.create_global_mesh((2,), "x")
|
||||
mesh = jtu.create_mesh((2,), "x")
|
||||
s = NamedSharding(mesh, P())
|
||||
|
||||
@functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
|
||||
@ -1273,7 +1273,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertDeleted(x)
|
||||
|
||||
def test_compute_offload_inside_shmap(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -1327,7 +1327,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertArraysAllClose(out, expected_out, rtol=1e-3)
|
||||
|
||||
def test_mem_kind_donation_pinned_host(self):
|
||||
mesh = jtu.create_global_mesh((2,), "x")
|
||||
mesh = jtu.create_mesh((2,), "x")
|
||||
s = NamedSharding(mesh, P(), memory_kind='pinned_host')
|
||||
s_dev = s.with_memory_kind('device')
|
||||
|
||||
@ -1349,7 +1349,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
|
||||
@parameterized.parameters("pinned_host", "device")
|
||||
def test_identity_mem_kind_donation(self, mem_kind):
|
||||
mesh = jtu.create_global_mesh((2,), "x")
|
||||
mesh = jtu.create_mesh((2,), "x")
|
||||
s = NamedSharding(mesh, P(), memory_kind=mem_kind)
|
||||
|
||||
@functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
|
||||
@ -1367,7 +1367,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
|
||||
@jtu.run_on_devices('tpu')
|
||||
def test_aot_device_implicit_transfer(self):
|
||||
mesh = jtu.create_global_mesh((1,), 'x')
|
||||
mesh = jtu.create_mesh((1,), 'x')
|
||||
np_inp = np.arange(8)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
|
||||
|
||||
@ -1397,7 +1397,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
|
||||
def test_remat_jaxpr_offloadable(self):
|
||||
mesh = jtu.create_global_mesh((2,), ("x",))
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))
|
||||
|
||||
def policy(prim, *avals, **params):
|
||||
@ -1440,7 +1440,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
|
||||
|
||||
def test_remat_scan_jaxpr_offloadable(self):
|
||||
mesh = jtu.create_global_mesh((2,), ("x",))
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
s = NamedSharding(mesh, P("x"))
|
||||
@ -1498,7 +1498,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
def test_remat_scan_layout_change_offloadable(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Remat scan does not work on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((2,), ("x",))
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
s = NamedSharding(mesh, P("x"))
|
||||
|
@ -354,7 +354,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
|
||||
def test_pjit_gda_multi_input_multi_output(self):
|
||||
jax.distributed.initialize()
|
||||
global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
|
||||
global_mesh = jtu.create_mesh((8, 2), ("x", "y"))
|
||||
global_input_shape = (16, 2)
|
||||
global_input_data = np.arange(
|
||||
util.prod(global_input_shape)).reshape(global_input_shape)
|
||||
@ -558,7 +558,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
def test_pjit_gda_eval_shape(self):
|
||||
jax.distributed.initialize()
|
||||
|
||||
with jtu.create_global_mesh((16,), ("x")):
|
||||
with jtu.create_mesh((16,), ("x")):
|
||||
|
||||
@functools.partial(pjit.pjit,
|
||||
in_shardings=jax.sharding.PartitionSpec(None),
|
||||
|
@ -54,7 +54,7 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
def testPGLEProfilerGetFDOProfile(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
||||
@partial(
|
||||
jax.jit,
|
||||
@ -83,7 +83,7 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
def testPGLEProfilerGetFDOProfileLarge(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
its = 500
|
||||
|
||||
@partial(
|
||||
@ -112,7 +112,7 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertEqual(fdo_profile.count(b'custom'), its)
|
||||
|
||||
def testAutoPgle(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
||||
@partial(
|
||||
jax.jit,
|
||||
@ -245,7 +245,7 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertFalse(pgle_profiler.is_fdo_consumed())
|
||||
|
||||
def testPassingFDOProfile(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
||||
@partial(
|
||||
jax.jit,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -957,7 +957,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
def test_make_array_from_callback(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
mesh = jtu.create_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
def callback(index):
|
||||
i = jnp.arange(len(devices))[index[0]]
|
||||
@ -969,7 +969,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
def test_make_array_from_single_device_arrays(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
mesh = jtu.create_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
keys = random.split(random.key(0), len(devices))
|
||||
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
|
||||
|
@ -39,7 +39,7 @@ class ShardAlikeDownstreamTest(jtu.JaxTestCase):
|
||||
|
||||
def test_full_like(self):
|
||||
x = jnp.arange(16, dtype='float32').reshape(8, 2)
|
||||
mesh = jtu.create_global_mesh((8,), ("i",))
|
||||
mesh = jtu.create_mesh((8,), ("i",))
|
||||
x = jax.device_put(x, NamedSharding(mesh, P('i', None)))
|
||||
y = jnp.full_like(x, 1)
|
||||
self.assertEqual(x.sharding, y.sharding)
|
||||
@ -51,7 +51,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
|
||||
def test_basic(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -68,7 +68,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp * np_inp * 4)
|
||||
|
||||
def test_output_sharded_alike_input(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -83,7 +83,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp * 2)
|
||||
|
||||
def test_arange_shard_alike_jit(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -98,7 +98,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
def test_different_shapes(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x',))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -113,7 +113,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
f(inp)
|
||||
|
||||
def test_double_shard_alike(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -131,7 +131,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
|
||||
|
||||
def test_shard_like_eager(self):
|
||||
mesh = jtu.create_global_mesh((4, 1), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 1), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -145,7 +145,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
def test_shard_map(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -167,7 +167,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out2.sharding, s)
|
||||
|
||||
def test_grad(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
np_inp = np.arange(8.)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -188,7 +188,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
jax.grad(jax.jit(f))(inp) # doesn't crash
|
||||
|
||||
def test_shard_input_as_output(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
np_inp = np.arange(8.)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
||||
@ -218,7 +218,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out4.sharding, s)
|
||||
|
||||
def test_shard_alike_inputs(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8.)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
rep_s = NamedSharding(mesh, P())
|
||||
@ -237,7 +237,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out2.sharding, s)
|
||||
|
||||
def test_vmap_one_mapped(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(2)
|
||||
s = NamedSharding(mesh, P('y'))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
@ -256,7 +256,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np.tile(np_inp, [8, 1]))
|
||||
|
||||
def test_vmap_both_mapped(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp1 = jax.device_put(np_inp, s)
|
||||
|
@ -749,7 +749,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
@unittest.skipIf(xla_extension_version < 281,
|
||||
'Requires xla_extension_version >= 281')
|
||||
def test_shard_map_abstract_mesh(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
@ -807,7 +807,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
@unittest.skipIf(xla_extension_version < 281,
|
||||
'Requires xla_extension_version >= 281')
|
||||
def test_shmap_abstract_mesh_errors(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8)
|
||||
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
|
||||
|
||||
@ -819,7 +819,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
out_specs=P('x'))(jnp.arange(8))
|
||||
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
mesh2 = jtu.create_global_mesh((2,), 'y')
|
||||
mesh2 = jtu.create_mesh((2,), 'y')
|
||||
abs_mesh2 = AbstractMesh(mesh2.shape_tuple)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -901,7 +901,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
@jax.legacy_prng_key('allow')
|
||||
def test_prngkeyarray_eager(self):
|
||||
# https://github.com/google/jax/issues/15398
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
rng = jax.random.PRNGKey(0)
|
||||
@ -917,7 +917,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
_ = g(sharded_rng) # don't crash!
|
||||
|
||||
def test_functools_partial_rank_error(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial
|
||||
def f(x):
|
||||
@ -929,7 +929,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
g(x)
|
||||
|
||||
def test_in_specs_none_error(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
def f(x): return x
|
||||
|
||||
@ -943,7 +943,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash
|
||||
|
||||
def test_scan_rep_rule(self):
|
||||
mesh = jtu.create_global_mesh((2, 2,), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
|
||||
|
||||
def f(x, y, z):
|
||||
x, y, z = x.sum(), y.sum(), z.sum()
|
||||
@ -996,7 +996,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
return foo(x), 3. * x_dot
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
|
||||
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
|
||||
self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
|
||||
@ -1015,7 +1015,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
|
||||
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
|
||||
self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
|
||||
@ -1029,7 +1029,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
if jit:
|
||||
foo = jax.jit(foo)
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))()
|
||||
expected = jnp.arange(4.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
@ -1045,7 +1045,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
if jit:
|
||||
foo = jax.jit(foo)
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(),
|
||||
out_specs=P('i', 'j'))()
|
||||
expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2))
|
||||
@ -1056,7 +1056,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans3, expected3, check_dtypes=False)
|
||||
|
||||
def test_axis_index_eager(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=P())
|
||||
def foo():
|
||||
@ -1068,7 +1068,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_jaxpr_shardings_with_no_outputs(self):
|
||||
# https://github.com/google/jax/issues/15385
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
@jax.jit
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('i'))
|
||||
@ -1084,7 +1084,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
g(np.arange(32)) # don't crash
|
||||
|
||||
def test_device_put(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
|
||||
def f(x):
|
||||
@ -1109,7 +1109,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_key_array_with_replicated_last_tile_dim(self):
|
||||
# See https://github.com/google/jax/issues/16137
|
||||
|
||||
mesh = jtu.create_global_mesh((2, 4), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 4), ('i', 'j'))
|
||||
|
||||
def f(rng):
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
|
||||
@ -1149,7 +1149,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, inputs_dce, order=2, modes=['rev'])
|
||||
|
||||
def test_returned_out_sharding(self):
|
||||
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
inp = jax.device_put(jnp.zeros((2, 2)), s)
|
||||
out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp)
|
||||
@ -1157,7 +1157,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, inp)
|
||||
|
||||
def test_dce(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
def f(x, y, z):
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')),
|
||||
@ -1208,7 +1208,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
check_diff=False)
|
||||
|
||||
def test_post_process_partial_eval_with_scalar_res(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh,
|
||||
in_specs=P(), out_specs=P())())(2.0)
|
||||
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
|
||||
@ -1239,7 +1239,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
|
||||
x = jnp.arange(4.)
|
||||
y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call
|
||||
@ -1248,7 +1248,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_rewrite_post_process_call(self):
|
||||
# We shouldn't hit post_process_call here because of RewriteTrace's dynamic
|
||||
# behavior (i.e. no data dependence).
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@jax.jit
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
|
||||
@ -1270,7 +1270,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
return foo(x), 2. * x_dot
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(lambda x: foo(x) * x, mesh,
|
||||
in_specs=(P('x'),), out_specs=P('x'))
|
||||
if jit:
|
||||
@ -1298,7 +1298,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(lambda x: foo(x) * x, mesh,
|
||||
in_specs=(P('x'),), out_specs=P('x'))
|
||||
if jit:
|
||||
@ -1326,7 +1326,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(lambda x: foo(x) * x, mesh,
|
||||
in_specs=(P('x'),), out_specs=P('x'))
|
||||
if jit:
|
||||
@ -1343,7 +1343,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_same_pspec_eager_shard_map(self):
|
||||
# This behavior is not guaranteed by JAX and this test can be changed if
|
||||
# the behavior changes.
|
||||
mesh = jtu.create_global_mesh((1, 4, 1), ('data', 'seq', 'model'))
|
||||
mesh = jtu.create_mesh((1, 4, 1), ('data', 'seq', 'model'))
|
||||
|
||||
def f(x):
|
||||
return x * x + 2
|
||||
@ -1371,7 +1371,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(lambda x, y: foo(x, y) * y, mesh,
|
||||
in_specs=(P(), P('x')), out_specs=P('x'))
|
||||
if jit:
|
||||
@ -1406,7 +1406,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
|
||||
return y
|
||||
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
g = shard_map(lambda x: foo_scan(x) * x, mesh,
|
||||
in_specs=(P('x'),), out_specs=P('x'))
|
||||
if jit:
|
||||
@ -1421,7 +1421,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
|
||||
|
||||
def test_transpose_identity(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
|
||||
def f(x):
|
||||
@ -1446,7 +1446,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertLen(e2.params['jaxpr'].eqns, 1)
|
||||
|
||||
def test_fanout_specs_transpose_to_psum(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
|
||||
def f(x):
|
||||
@ -1459,7 +1459,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(e2.params['axes'], ('x',))
|
||||
|
||||
def test_fanin_psum_transposes_to_fanout(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
|
||||
def f(x):
|
||||
@ -1471,7 +1471,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(str(e1.primitive), 'pbroadcast')
|
||||
|
||||
def test_psum_with_implicit_fanout_self_transposes(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
|
||||
def f(x):
|
||||
@ -1484,7 +1484,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(str(e2.primitive), 'pbroadcast')
|
||||
|
||||
def test_rewrite_binops(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
|
||||
def f(x, y):
|
||||
@ -1497,7 +1497,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(e.params['axes'], ('x',))
|
||||
|
||||
def test_rewrite_scan(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
|
||||
def f(x):
|
||||
@ -1515,7 +1515,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_check_rep_false_grads(self):
|
||||
# This test is redundant with the systematic tests below, but it serves as a
|
||||
# direct regression test for a bug.
|
||||
mesh = jtu.create_global_mesh((4,), ('heads',))
|
||||
mesh = jtu.create_mesh((4,), ('heads',))
|
||||
|
||||
def f(q, k, v):
|
||||
|
||||
@ -1549,7 +1549,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters(it.product([True, False], repeat=2))
|
||||
def test_res_forwarding_optimization(self, jit, remat):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
|
||||
def f(x):
|
||||
@ -1572,7 +1572,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters(it.product([True, False], repeat=2))
|
||||
def test_res_forwarding_optimization_complex(self, jit, remat):
|
||||
# like the above test, but a different function `f`
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
|
||||
def f(x):
|
||||
@ -1594,7 +1594,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_check_rep_failure_inside_rule(self, jit):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
def loss(w, x):
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
|
||||
@ -1608,7 +1608,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jax.grad(loss)(3.0, jnp.arange(8.)) # don't crash
|
||||
|
||||
def test_conv_general_dilated(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
|
||||
dot = partial(lax.conv_general_dilated, window_strides=(),
|
||||
padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
|
||||
@ -1624,25 +1624,25 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_cumsum(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
x = jnp.arange(8.)
|
||||
shard_map(jnp.cumsum, mesh=mesh, in_specs=P('i'), out_specs=P('i')
|
||||
)(x) # don't crash
|
||||
|
||||
def test_custom_jvp_inside_jit(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('batch',))
|
||||
mesh = jtu.create_mesh((4,), ('batch',))
|
||||
x = shard_map(jax.jit(jax.nn.relu),
|
||||
mesh=mesh, in_specs=P('batch'),
|
||||
out_specs=P('batch'))(jnp.arange(16.)) # don't crash
|
||||
|
||||
def test_random_normal_rules(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
keys = jax.random.split(jax.random.key(0), 4)
|
||||
shard_map(lambda k: jax.random.normal(k[0], (1,)),
|
||||
mesh=mesh, in_specs=P('i'), out_specs=P('i'))(keys) # don't crash
|
||||
|
||||
def test_erf_rules(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('i',))
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
x = jnp.arange(16.)
|
||||
shard_map(jax.lax.erf,
|
||||
mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash
|
||||
@ -1732,7 +1732,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
modes=['rev'], atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_partial_auto(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
x = jax.lax.with_sharding_constraint(
|
||||
@ -1759,7 +1759,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_sharded_prng_with_abstract_mesh(self):
|
||||
shape = (8, 2, 2)
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape)
|
||||
key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl)
|
||||
@ -1774,7 +1774,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
def test_partial_auto_error_wsc_manual(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
x = jax.lax.with_sharding_constraint(
|
||||
@ -1797,7 +1797,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
f(v)
|
||||
|
||||
def test_partial_auto_error_invalid_auto(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
x = jax.lax.with_sharding_constraint(
|
||||
@ -1820,7 +1820,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
f(v)
|
||||
|
||||
def test_partial_auto_error_wrong_in_specs(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
x = jax.lax.with_sharding_constraint(
|
||||
@ -1843,7 +1843,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
f(v)
|
||||
|
||||
def test_nested_partial_auto(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
return x * x
|
||||
@ -1866,7 +1866,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(v*v, f(v), check_dtypes=False)
|
||||
|
||||
def test_axis_size_1_partial_auto(self):
|
||||
mesh = jtu.create_global_mesh((1, 2, 2), ('i', 'j', 'k'))
|
||||
mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
|
||||
|
||||
def h(x):
|
||||
return x * x
|
||||
@ -1884,7 +1884,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(v*v, f(v), check_dtypes=False)
|
||||
|
||||
def test_partial_auto_of_pjit(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def h():
|
||||
def _make_zeros():
|
||||
@ -1901,7 +1901,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))
|
||||
|
||||
def test_partial_auto_of_pjit_different_mesh(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l'))
|
||||
|
||||
def h():
|
||||
@ -1920,7 +1920,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
|
||||
# https://github.com/google/jax/pull/21032
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
@ -1937,7 +1937,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self):
|
||||
# https://github.com/google/jax/pull/21056
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
@partial(jax.remat, policy=lambda *_, **__: True)
|
||||
@partial(
|
||||
@ -1955,7 +1955,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
|
||||
# https://github.com/google/jax/issues/21236
|
||||
mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
|
||||
mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
@ -1975,7 +1975,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_vmap_spmd_axis_name_error(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
@ -2005,7 +2005,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jax.vmap(g, spmd_axis_name='i')(xs)
|
||||
|
||||
def test_in_spec_none(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
x = jnp.arange(8).reshape(4, 2)
|
||||
|
||||
@ -2056,7 +2056,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
|
||||
|
||||
def test_in_spec_none_divisibility_errors(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
x = jnp.arange(4).reshape(2, 2)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'divisible'):
|
||||
@ -2078,7 +2078,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
)((object(), object()), x)
|
||||
|
||||
def test_in_spec_none_rank_errors(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
|
||||
x = jnp.arange(4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'rank'):
|
||||
@ -2101,7 +2101,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_custom_linear_solve_rep_rules(self):
|
||||
# https://github.com/google/jax/issues/20162
|
||||
mesh = jtu.create_global_mesh((1,), ('i',))
|
||||
mesh = jtu.create_mesh((1,), ('i',))
|
||||
a = jnp.array(1).reshape(1, 1)
|
||||
b = jnp.array(1).reshape(1)
|
||||
|
||||
@ -2113,7 +2113,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
_ = f(a, b) # don't crash
|
||||
|
||||
def test_temporary_error_suppression_flag(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('i',))
|
||||
mesh = jtu.create_mesh((2,), ('i',))
|
||||
|
||||
def f(x, y):
|
||||
z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True),
|
||||
@ -2375,7 +2375,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
|
||||
@staticmethod
|
||||
def make_mesh(mesh_shape):
|
||||
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
|
||||
return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
|
Loading…
x
Reference in New Issue
Block a user