# rocm_jax/tests/pallas/tpu_pallas_distributed_test.py
# 2025-01-28 21:18:47 +00:00
#
# 583 lines
# 19 KiB
# Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for distributed pallas TPU operations."""
import functools
import os
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax.experimental import mesh_utils
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
# Hook JAX's config flags into absl flag parsing before any test runs.
jax.config.parse_flags_with_absl()
# Short aliases used by the tests in this module.
P = jax.sharding.PartitionSpec
partial = functools.partial
class PallasCallRemoteDMATest(parameterized.TestCase):
  """Collective-permute style tests built on Pallas TPU remote DMAs.

  Every test wraps a `pl.pallas_call` in `shard_map` + `jax.jit`, has each
  device copy its (8, 128) shard to a peer with `pltpu.async_remote_copy`, and
  compares the gathered result against the input rolled by one shard.
  """

  def setUp(self):
    """Skips unless at least two devices are present and the TPU is a v5e."""
    super().setUp()
    if jax.device_count() < 2:
      self.skipTest('Only >=2 devices are supported.')
    if not jtu.is_device_tpu(5, 'e'):
      self.skipTest('Only works with TPU v5e.')

  @parameterized.named_parameters(
      ('vmem', pltpu.TPUMemorySpace.VMEM),
      ('hbm', pltpu.TPUMemorySpace.ANY),
  )
  def test_basic_remote_vmem_dma(self, mem):
    """Swaps the shards of two devices, addressing the peer by logical id.

    Args:
      mem: memory space used for both the input and output block specs.
    """
    # Implements very simple collective permute
    def kernel(x_ref, y_ref):
      def body(ready_sem, send_sem, recv_sem):
        dev_id = pltpu.device_id()
        # Only two devices take part, so the peer is always "the other one".
        other_dev_id = 1 - dev_id
        # Handshake: signal the peer and wait for its signal before copying.
        pltpu.semaphore_signal(ready_sem, device_id=other_dev_id,
                               device_id_type=pltpu.DeviceIdType.LOGICAL)
        pltpu.semaphore_wait(ready_sem)
        copy_done = pltpu.async_remote_copy(
            x_ref, y_ref, send_sem, recv_sem, other_dev_id,
            device_id_type=pltpu.DeviceIdType.LOGICAL,
        )
        copy_done.wait_send()
        copy_done.wait_recv()

      pl.run_scoped(
          body,
          pltpu.SemaphoreType.REGULAR,
          pltpu.SemaphoreType.DMA,
          pltpu.SemaphoreType.DMA,
      )

    x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128))

    def body(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec(memory_space=mem)],
          out_specs=pl.BlockSpec(memory_space=mem),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
      )(x)

    devices = jax.devices()[:2]
    mesh = jax.sharding.Mesh(devices, ['x'])
    y = jax.jit(
        shard_map.shard_map(
            body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False
        )
    )(x)
    # The two 8-row halves trade places.
    expected = jnp.concatenate([x[8:], x[:8]])
    np.testing.assert_allclose(y, expected)

  @parameterized.named_parameters(
      ('left', 'left'),
      ('right', 'right')
  )
  def test_pallas_call_axis_index(self, direction):
    """Ring permute on a 1D mesh, locating the peer with `lax.axis_index`.

    Args:
      direction: 'right' sends to index + 1, anything else sends to index - 1.
    """
    # Implements very simple collective permute
    def kernel(x_ref, y_ref):
      def body(ready_sem, send_sem, recv_sem):
        my_id = lax.axis_index('x')
        num_devices = lax.psum(1, 'x')
        if direction == 'right':
          neighbor = lax.rem(my_id + 1, num_devices)
        else:
          neighbor = lax.rem(my_id - 1, num_devices)
          # Neighbor might be negative here so we add num_devices in case
          neighbor = jnp.where(neighbor < 0, neighbor + num_devices, neighbor)
        pltpu.semaphore_signal(ready_sem, device_id=neighbor)
        pltpu.semaphore_wait(ready_sem)
        copy_done = pltpu.async_remote_copy(
            x_ref, y_ref, send_sem, recv_sem, device_id=neighbor
        )
        copy_done.wait_send()
        copy_done.wait_recv()

      pl.run_scoped(
          body,
          pltpu.SemaphoreType.REGULAR,
          pltpu.SemaphoreType.DMA,
          pltpu.SemaphoreType.DMA,
      )

    # Sized from jax.device_count() so that it agrees with the mesh below,
    # which spans jax.devices(); each shard is then exactly (8, 128).
    num_devices = jax.device_count()
    x = jnp.arange(num_devices * 8 * 128).reshape((num_devices * 8, 128))

    def body(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
          out_shape=x,
      )(x)

    device_mesh = mesh_utils.create_device_mesh(
        (jax.device_count(),), jax.devices())
    mesh = jax.sharding.Mesh(device_mesh, ['x'])
    y = jax.jit(
        shard_map.shard_map(
            body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False
        )
    )(x)
    if direction == 'right':
      expected = jnp.concatenate([x[-8:], x[:-8]])
    else:
      expected = jnp.concatenate([x[8:], x[:8]])
    np.testing.assert_allclose(y, expected)

  @parameterized.named_parameters(('left', 'left'), ('right', 'right'))
  def test_pallas_call_axis_index_2d_mesh(self, direction):
    """Ring permute along 'x' of a (2, n) mesh with tuple device ids.

    Args:
      direction: 'right' sends to index + 1, anything else sends to index - 1.
    """
    # Implements very simple collective permute in a 2D mesh.
    def kernel(x_ref, y_ref):
      def body(ready_sem, send_sem, recv_sem):
        my_id = lax.axis_index('x')
        my_other_id = lax.axis_index('y')
        axis_size = lax.psum(1, 'x')
        if direction == 'right':
          neighbor = lax.rem(my_id + 1, axis_size)
        else:
          neighbor = lax.rem(my_id - 1, axis_size)
          # Neighbor might be negative here so we add axis_size in case
          neighbor = jnp.where(neighbor < 0, neighbor + axis_size, neighbor)
        # The device id is a (y index, x index) pair, in the same order as the
        # mesh axes ['y', 'x'] below; the peer keeps our 'y' coordinate.
        pltpu.semaphore_signal(ready_sem, device_id=(my_other_id, neighbor))
        pltpu.semaphore_wait(ready_sem)
        copy_done = pltpu.async_remote_copy(
            x_ref, y_ref, send_sem, recv_sem, device_id=(my_other_id, neighbor)
        )
        copy_done.wait_send()
        copy_done.wait_recv()

      pl.run_scoped(
          body,
          pltpu.SemaphoreType.REGULAR,
          pltpu.SemaphoreType.DMA,
          pltpu.SemaphoreType.DMA,
      )

    axis_size = jax.device_count() // 2
    x = jnp.arange(axis_size * 8 * 128).reshape((axis_size * 8, 128))

    def body(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
          out_shape=x,
      )(x)

    device_mesh = mesh_utils.create_device_mesh(
        (2, axis_size), jax.devices()
    )
    mesh = jax.sharding.Mesh(device_mesh, ['y', 'x'])
    y = jax.jit(
        shard_map.shard_map(
            body,
            mesh,
            in_specs=P('x', None),
            out_specs=P('x', None),
            check_rep=False,
        )
    )(x)
    if direction == 'right':
      expected = jnp.concatenate([x[-8:], x[:-8]])
    else:
      expected = jnp.concatenate([x[8:], x[:8]])
    np.testing.assert_allclose(y, expected)

  def test_barrier_semaphore(self):
    """Ring permute to the right, preceded by a barrier-semaphore handshake."""
    def kernel(x_ref, y_ref):
      def body(ready_sem, send_sem, recv_sem):
        my_id = lax.axis_index('x')
        num_devices = lax.psum(1, 'x')
        neighbor = lax.rem(my_id + 1, num_devices)
        barrier_sem = pltpu.get_barrier_semaphore()
        pltpu.semaphore_signal(barrier_sem, device_id=neighbor)
        pltpu.semaphore_wait(barrier_sem)
        pltpu.semaphore_signal(ready_sem, device_id=neighbor)
        pltpu.semaphore_wait(ready_sem)
        pltpu.async_remote_copy(
            x_ref, y_ref, send_sem, recv_sem, device_id=neighbor
        ).wait()

      pl.run_scoped(
          body,
          pltpu.SemaphoreType.REGULAR,
          pltpu.SemaphoreType.DMA,
          pltpu.SemaphoreType.DMA,
      )

    # Sized from jax.device_count() so that it agrees with the mesh below,
    # which spans jax.devices(); each shard is then exactly (8, 128).
    num_devices = jax.device_count()
    x = jnp.arange(num_devices * 8 * 128).reshape((num_devices * 8, 128))

    def body(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
          out_shape=x,
          # NOTE(review): this is the only kernel that passes collective_id and
          # the only one using get_barrier_semaphore -- presumably related;
          # confirm against the Pallas TPU documentation.
          compiler_params=dict(mosaic=dict(collective_id=0)),
      )(x)

    device_mesh = mesh_utils.create_device_mesh(
        (jax.device_count(),), jax.devices())
    mesh = jax.sharding.Mesh(device_mesh, ['x'])
    y = jax.jit(
        shard_map.shard_map(
            body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False
        )
    )(x)
    expected = jnp.concatenate([x[-8:], x[:-8]])
    np.testing.assert_allclose(y, expected)
class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
  """Remote-DMA tests that run `pl.pallas_call` with `interpret=True`.

  The interpreted result is checked either against `lax.ppermute` or against
  the same kernel run without `interpret=True`.
  """

  def setUp(self):
    """Skips the test unless it is running on a TPU."""
    super().setUp()
    if not jtu.is_device_tpu():
      self.skipTest('Test requires TPU')

  @staticmethod
  def _make_sharded_input(mesh, num_devices):
    """Returns a random (8, 128 * num_devices) array sharded along 'x'."""
    sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
    unsharded_arr = jax.random.normal(
        jax.random.key(0), shape=(8, 128 * num_devices))
    return jax.device_put(unsharded_arr, sharding)

  def _assert_interpret_matches_compiled(
      self, test_kernel, out_shape, grid_spec, mesh, sharded_arr):
    """Runs `test_kernel` interpreted, then compiled, and compares outputs."""
    results = []
    # Interpret mode first, then the regular lowering.
    for mode_kwargs in (dict(interpret=True), {}):
      kernel = pl.pallas_call(
          test_kernel,
          out_shape=out_shape,
          grid_spec=grid_spec,
          **mode_kwargs,
      )
      compiled_func = jax.jit(shard_map.shard_map(
          kernel,
          mesh=mesh,
          in_specs=P(None, 'x'),
          out_specs=P(None, 'x'),
          check_rep=False))
      results.append(compiled_func(sharded_arr))
    result_interpret, result_noninterpret = results
    np.testing.assert_allclose(result_interpret,
                               result_noninterpret,
                               atol=1e-5,
                               rtol=1e-3)

  @parameterized.parameters(('left',), ('right',))
  def test_interpret_remote_dma_ppermute(self, permutation):
    """Checks an interpreted remote-DMA permute against `lax.ppermute`.

    Args:
      permutation: 'left' maps index i to i - 1 (mod n); anything else maps it
        to i + 1 (mod n).
    """
    self.skipTest("ROCm: Skipping for now")
    if jax.device_count() <= 1:
      self.skipTest('Test requires multiple devices.')
    num_devices = jax.device_count()
    # Adding num_devices before taking the remainder keeps the dividend
    # non-negative for the 'left' shift.
    offset = num_devices - 1 if permutation == 'left' else num_devices + 1

    def permute_fn(x):
      return lax.rem(x + offset, num_devices)

    # Construct a kernel which performs a ppermute based on permute_fn.
    def test_kernel(x_ref,
                    o_ref,
                    copy_send_sem,
                    copy_recv_sem,
                    ):
      o_ref[...] = jnp.zeros_like(o_ref[...])
      my_id = lax.axis_index('x')
      dst_device = permute_fn(my_id)
      input_to_output_copy = pltpu.make_async_remote_copy(
          src_ref=x_ref,
          dst_ref=o_ref,
          send_sem=copy_send_sem,
          recv_sem=copy_recv_sem,
          device_id=dst_device,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
      )
      input_to_output_copy.start()
      input_to_output_copy.wait()

    out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        ],
        scratch_shapes=(
            [pltpu.SemaphoreType.DMA] * 2
        )
    )

    devices = mesh_utils.create_device_mesh((num_devices,))
    mesh = jax.sharding.Mesh(devices, ('x',))
    sharded_arr = self._make_sharded_input(mesh, num_devices)

    kernel = pl.pallas_call(
        test_kernel,
        out_shape=out_shape,
        grid_spec=grid_spec,
        interpret=True,
    )
    compiled_func = jax.jit(shard_map.shard_map(
        kernel,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P(None, 'x'),
        check_rep=False))
    result = compiled_func(sharded_arr)

    # Reference: the same (source, destination) pairs fed to lax.ppermute.
    perm = tuple((src, permute_fn(src)) for src in range(num_devices))
    perm = jax.tree_util.tree_map(int, perm)

    def lax_permute(x):
      return lax.ppermute(x, 'x', perm)

    expected = jax.jit(shard_map.shard_map(lax_permute,
                                           mesh=mesh,
                                           in_specs=P(None, 'x'),
                                           out_specs=P(None, 'x')))(sharded_arr)
    np.testing.assert_array_equal(result, expected)

  def test_interpret_remote_dma_asymmetrical_indexer(self):
    """Compares interpret and compiled runs when dst slices differ by device."""
    # Test DMAs where destination slices are not the same.
    # jax.device_count() is used because create_device_mesh below is called
    # without an explicit device list.
    if jax.device_count() <= 1:
      self.skipTest('Test requires multiple devices.')
    if not jtu.is_device_tpu(5, 'e'):
      self.skipTest('Only works with TPU v5e.')
    num_devices = jax.device_count()

    def test_kernel(x_ref,
                    output_ref,
                    send_sem,
                    recv_sem):
      output_ref[...] = jnp.zeros_like(output_ref[...])
      my_id = lax.axis_index('x')
      # rem(my_id, 2) is 1 for odd ids and 0 for even ids.
      is_odd = lax.rem(my_id, 2)
      is_even = 1 - is_odd
      neighbor = lax.rem(my_id + 1, num_devices)

      # If the device_id is odd, we copy to output_ref[1].
      # If it's even, we copy to output_ref[0].
      @pl.when(is_odd)
      def _():
        remote_dma = pltpu.make_async_remote_copy(
            src_ref=x_ref,
            dst_ref=output_ref.at[1],
            send_sem=send_sem,
            recv_sem=recv_sem,
            device_id=neighbor,
        )
        remote_dma.start()
        remote_dma.wait()

      @pl.when(is_even)
      def _():
        remote_dma = pltpu.make_async_remote_copy(
            src_ref=x_ref,
            dst_ref=output_ref.at[0],
            send_sem=send_sem,
            recv_sem=recv_sem,
            device_id=neighbor,
        )
        remote_dma.start()
        remote_dma.wait()

    out_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        scratch_shapes=(
            [pltpu.SemaphoreType.DMA] * 2
        )
    )

    devices = mesh_utils.create_device_mesh((num_devices,))
    # Axis names are given as a plain tuple rather than a PartitionSpec.
    mesh = jax.sharding.Mesh(devices, ('x',))
    sharded_arr = self._make_sharded_input(mesh, num_devices)

    # Compare interpret mode result to non-interpret mode result.
    self._assert_interpret_matches_compiled(
        test_kernel, out_shape, grid_spec, mesh, sharded_arr)

  def test_interpret_remote_dma_asymmetrical_refs(self):
    """Compares interpret and compiled runs when dst refs differ by device.

    Currently skipped unconditionally as a known failure.
    """
    # Test DMAs where dst refs are not the same.
    self.skipTest('Known failure.')
    # jax.device_count() is used because create_device_mesh below is called
    # without an explicit device list.
    num_devices = jax.device_count()

    def test_kernel(x_ref,
                    even_output,
                    odd_output,
                    send_sem,
                    recv_sem):
      even_output[...] = jnp.zeros_like(even_output[...])
      odd_output[...] = jnp.zeros_like(odd_output[...])
      my_id = lax.axis_index('x')
      # rem(my_id, 2) is 1 for odd ids and 0 for even ids.
      is_odd = lax.rem(my_id, 2)
      is_even = 1 - is_odd
      neighbor = lax.rem(my_id + 1, num_devices)

      # NOTE(review): as written, odd ids write to `even_output` and even ids
      # to `odd_output`; kept as-is since the test is skipped -- confirm the
      # intended pairing before re-enabling.
      @pl.when(is_odd)
      def _():
        remote_dma = pltpu.make_async_remote_copy(
            src_ref=x_ref,
            dst_ref=even_output,
            send_sem=send_sem,
            recv_sem=recv_sem,
            device_id=neighbor,
            device_id_type=pltpu.DeviceIdType.LOGICAL,
        )
        remote_dma.start()
        remote_dma.wait()

      @pl.when(is_even)
      def _():
        remote_dma = pltpu.make_async_remote_copy(
            src_ref=x_ref,
            dst_ref=odd_output,
            send_sem=send_sem,
            recv_sem=recv_sem,
            device_id=neighbor,
            device_id_type=pltpu.DeviceIdType.LOGICAL,
        )
        remote_dma.start()
        remote_dma.wait()

    out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ],
        out_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ],
        scratch_shapes=(
            [pltpu.SemaphoreType.DMA] * 2
        )
    )

    devices = mesh_utils.create_device_mesh((1, num_devices))
    # Same axis names the original PartitionSpec carried, as a plain tuple.
    mesh = jax.sharding.Mesh(devices, (None, 'x'))
    sharded_arr = self._make_sharded_input(mesh, num_devices)

    # Compare interpret mode result to non-interpret mode result.
    self._assert_interpret_matches_compiled(
        test_kernel, (out_shape, out_shape), grid_spec, mesh, sharded_arr)
class VerificationTest(jtu.JaxTestCase):
  """Smoke test for dumping a distributed kernel via the Promela config."""

  def test_verification(self):
    """Runs a ring kernel with `jax_pallas_dump_promela_to` set.

    Passes when the call does not crash and the dump directory is non-empty.
    """
    if (num_devices := jax.local_device_count()) <= 1:
      self.skipTest('Test requires multiple devices.')
    if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1:
      self.skipTest('Test requires a new single-core TPU.')

    def kernel_body(in_ref, out_ref, scratch_ref, send_sem, recv_sem,
                    capacity_sem):
      my_id = lax.axis_index('x')
      # Ring neighbours: we send to dst_id and signal capacity to src_id.
      dst_id = jnp.where(my_id == num_devices - 1, 0, my_id + 1)
      src_id = jnp.where(my_id == 0, num_devices - 1, my_id - 1)
      pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id)
      out_ref[...] = jnp.zeros_like(out_ref)
      scratch_ref[0] = in_ref[0]

      @functools.partial(lax.fori_loop, 0, num_devices - 1, init_val=None)
      def _(i, _):
        # The two scratch slots alternate between source and destination.
        slot = i % 2
        next_slot = 1 - slot
        pltpu.semaphore_wait(capacity_sem, 1)
        copy = pltpu.async_remote_copy(
            scratch_ref.at[slot],
            scratch_ref.at[next_slot],
            send_sem,
            recv_sem,
            device_id=dst_id,
        )
        out_ref[...] += scratch_ref[slot]
        copy.wait()
        pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id)

      # Add the slot written by the final loop iteration.
      out_ref[...] += scratch_ref[(num_devices - 1) % 2]
      pltpu.semaphore_wait(capacity_sem, 1)

    kernel = pl.pallas_call(
        kernel_body,
        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
        scratch_shapes=[
            pltpu.VMEM((2, 128, 128), jnp.float32),
            pltpu.SemaphoreType.DMA,
            pltpu.SemaphoreType.DMA,
            pltpu.SemaphoreType.REGULAR,
        ],
    )
    devices = mesh_utils.create_device_mesh((num_devices,))
    mesh = jax.sharding.Mesh(devices, ['x'])
    # This is just a smoke test to ensure that the verification does not crash.
    with tempfile.TemporaryDirectory() as tmpdir:
      previous_config = jax.config.read('jax_pallas_dump_promela_to')
      jax.config.update('jax_pallas_dump_promela_to', tmpdir)
      try:
        shard_map.shard_map(
            kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None),
            check_rep=False
        )(jnp.ones((8, 128, 128), jnp.float32))
      finally:
        # Restore the global flag even when the call above raises, so later
        # tests do not keep dumping into a directory that is about to vanish.
        jax.config.update('jax_pallas_dump_promela_to', previous_config)
      self.assertNotEmpty(os.listdir(tmpdir))
if __name__ == '__main__':
  # Run the module's tests through absltest using JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())