rocm_jax/tests/pallas/tpu_pallas_interpret_distributed_test.py
2025-03-05 15:15:21 -08:00

1101 lines
36 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TPU-specific interpret mode.
To work around https://github.com/jax-ml/jax/issues/25671 , this file
contains only tests that use shard_map.
"""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
import jax._src.pallas.mosaic.interpret as mosaic_interpret
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
P = jax.sharding.PartitionSpec
class InterpretDistributedTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jax.device_count() < 4:
self.skipTest(f'requires at least 4 devices, found {jax.device_count()}')
@parameterized.product(
dma_execution_mode=['eager', 'on_wait'],
detect_races=[True, False])
def test_right_permute_example(self, dma_execution_mode, detect_races):
num_devices = jax.device_count()
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(
jax.random.key(0), (8, 128 * num_devices), dtype=jnp.float32)
input_arr = jax.device_put(input_arr, sharding)
def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
my_id = lax.axis_index('x')
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH)
pltpu.semaphore_signal(
barrier_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH)
pltpu.semaphore_wait(barrier_sem, 2)
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# We allocate DMA semaphores in scratch memory.
[pltpu.SemaphoreType.DMA] * 2
),
)
right_permute = pl.pallas_call(
right_permute_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=pltpu.TPUCompilerParams(collective_id=13),
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode=dma_execution_mode, detect_races=detect_races),
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
right_permute,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.ppermute(x, 'x', perm),
mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result)
if detect_races:
self.assertFalse(mosaic_interpret.races.races_found)
@parameterized.product(
dma_execution_mode=['eager', 'on_wait'],
detect_races=[True, False])
def test_all_gather_example(self, dma_execution_mode, detect_races):
num_devices = jax.device_count()
partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(
jax.random.key(0), (8 * num_devices, 128), dtype=jnp.float32)
input_arr = jax.device_put(input_arr, sharding)
def all_gather_kernel(input_ref,
output_ref,
local_copy_sem,
send_sem,
recv_sems):
outer_step = pl.program_id(0)
my_id = lax.axis_index('x')
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
copy_slot = my_id - outer_step
copy_slot = lax.rem(copy_slot + num_devices, jnp.int32(num_devices))
@pl.when(outer_step == 0)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
local_copy_op = pltpu.make_async_copy(
src_ref=input_ref,
dst_ref=output_ref.at[my_id],
sem=local_copy_sem,
)
local_copy_op.start()
local_copy_op.wait()
# Copy to our right neighbor.
# Note that we will also be receiving data from our left neighbor,
# but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
# that the indices do not need to be symmetric between remote DMAs.
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=output_ref.at[copy_slot],
dst_ref=output_ref.at[copy_slot],
send_sem=send_sem,
recv_sem=recv_sems.at[outer_step],
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# DMA semaphores are allocated in scratch memory.
# We allocated one semaphore for a local HBM-VMEM copy,
# and one for the remote send semaphore.
[pltpu.SemaphoreType.DMA] * 2
# We additionally allocate one receive semaphore per device.
# This is to avoid situations where we have multiple
# DMAs in flight, as we do not want to share a receive
# semaphore between the DMAs.
+ [pltpu.SemaphoreType.DMA((num_devices-1,))]
),
grid=(num_devices-1,)
)
all_gather = pl.pallas_call(
all_gather_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode=dma_execution_mode, detect_races=detect_races),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
all_gather,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.all_gather(x, 'x'),
mesh=mesh, in_specs=partition, out_specs=partition
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result)
if detect_races:
self.assertFalse(mosaic_interpret.races.races_found)
@parameterized.product(
dma_execution_mode=['eager', 'on_wait'],
detect_races=[True, False])
def test_all_reduce_sum_example(self, dma_execution_mode, detect_races):
num_devices = jax.device_count()
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
input_arr = jax.random.uniform(
jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)
def all_reduce_kernel(
x_ref,
o_ref,
hbm_scratch,
copy_sem,
remote_recv_sem,
remote_send_sem,
capacity_sem,
receive_scratch,
):
outer_step = pl.program_id(0)
working_slot = lax.rem(outer_step, jnp.int32(2))
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
left_neighbor = lax.rem(my_id - 1 + num_devices, jnp.int32(num_devices))
@pl.when(outer_step == 0)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
# Initialize o_ref, acc_scratch, and hbm_scratch.
o_ref[...] = jnp.zeros_like(o_ref)
receive_scratch[...] = jnp.zeros_like(receive_scratch)
initial_copy = pltpu.make_async_remote_copy(
src_ref=x_ref,
dst_ref=hbm_scratch.at[working_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_copy.start()
initial_copy.wait()
# Signal to our left neighbor that we are ready to receive.
# Without this signal, our left neighbor can be >=1 iteration ahead,
# meaning it could write into our working slot.
pltpu.semaphore_signal(
capacity_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# Copy the partial result our left neighbor sent to us into VMEM for
# computation.
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=receive_scratch,
sem=copy_sem,
)
local_copy.start()
# Block until our right neighbor is ready to receive.
pltpu.semaphore_wait(capacity_sem, 1)
# Pass the value to our right neighbor.
remote_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=hbm_scratch.at[receiving_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy.start()
# Finish local copy and accumulate while remote_copy is happening.
local_copy.wait()
o_ref[...] += receive_scratch[...]
# Block until remote copy finishes.
remote_copy.wait()
out_shape = (
jax.ShapeDtypeStruct((8, 128), input_arr.dtype),
# We allocate the double-buffer as a Pallas output so that it is
# resident in HBM.
jax.ShapeDtypeStruct((2, 8, 128), input_arr.dtype), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# Our input lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
# Our output lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
# Our double-buffer lives in HBM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices,),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 3
+ [pltpu.SemaphoreType.REGULAR] # capacity_sem
+ [pltpu.VMEM((8, 128), input_arr.dtype)] # receive_scratch
),
)
kernel = pl.pallas_call(
all_reduce_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode=dma_execution_mode, detect_races=detect_races),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
pallas_result = jax.jit(
shard_map.shard_map(
kernel,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]
def lax_sum(x):
return lax.psum(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
if detect_races:
self.assertFalse(mosaic_interpret.races.races_found)
@parameterized.product(
dma_execution_mode=['eager', 'on_wait'],
detect_races=[True, False])
def test_reduce_scatter_sum_example(self, dma_execution_mode, detect_races):
num_devices = jax.device_count()
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(block_size[0] * num_devices, block_size[1] * num_devices),
dtype=jnp.float32,
)
input_arr = jax.device_put(input_arr, sharding)
LEFT = 0
RIGHT = 1
def mod(x, n):
return lax.rem(x + n, n)
def signal(left_or_right, semaphore):
my_id = lax.axis_index('x')
if left_or_right == LEFT:
neighbor = mod(my_id - 1, jnp.int32(num_devices))
else:
neighbor = mod(my_id + 1, jnp.int32(num_devices))
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
local_copy_sem,
left_recv_sem,
left_send_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
accum_scratch,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, jnp.int32(2))
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, jnp.int32(num_devices))
left_neighbor = mod(my_id - 1, jnp.int32(num_devices))
left_copy_device = mod(my_id + outer_step + 1, jnp.int32(num_devices))
right_copy_device = mod(my_id - outer_step - 1, jnp.int32(num_devices))
# Slices can be specified using pl.ds(start, size)
left_copy_slice = pl.ds(0, block_size[0] // 2)
right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
# Note: Right copy is flipped with regards to slots since we are copying
# to the next outer_step iteration.
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
o_ref[...] = jnp.zeros_like(o_ref[...])
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
# --- Body ---
# At the beginning of our kernel body, we start a DMA which copies
# the result we computed in the previous phase to our neighbor.
# This allows us to overlap the communication of sending our previous phase
# with the computation for the current phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
# We block here until our right neighbor tells use we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
# We block here until our left neighbor tells use we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=accum_scratch,
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
@pl.when(phase == RIGHT)
def _():
accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
src_ref=accum_scratch,
dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(is_start)
def _():
initial_right_copy.wait()
# At the end of our kernel body, we wait on the DMA of the previous phase
# to make sure the results are ready for the next phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# --- Epilogue ---
# Store result on last iteration.
@pl.when(last_iteration)
def _():
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
o_ref[left_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
o_ref[right_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, block_size[0], block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
+ [
pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
] # accum_scratch
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode=dma_execution_mode, detect_races=True),
compiler_params=pltpu.TPUCompilerParams(collective_id=7),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, block_size[0], block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
if detect_races:
self.assertFalse(mosaic_interpret.races.races_found)
@parameterized.product(
dma_execution_mode=['eager', 'on_wait'],
detect_races=[True, False])
def test_reduce_scatter_sum_with_emit_pipeline_example(
self, dma_execution_mode, detect_races):
self.skipTest('requires a patched pallas.emit_pipeline to specify/fake '
'the TPU generation')
if jax.config.jax_enable_x64:
self.skipTest('pallas.emit_pipeline + x64 is not currently supported')
num_devices = jax.device_count()
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (512, 512)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(
outer_block_size[0] * num_devices,
outer_block_size[1] * num_devices,
),
dtype=jnp.float32,
)
input_arr = jax.device_put(input_arr, sharding)
inner_grid = (
outer_block_size[0] // inner_block_size[0] // 2,
outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
index_map=lambda i, j: (i, j),
block_shape=inner_block_size,
memory_space=pltpu.TPUMemorySpace.ANY,
)
LEFT = 0
RIGHT = 1
def mod(x, n):
return lax.rem(x + n, n)
def signal(left_or_right, semaphore):
my_id = lax.axis_index('x')
if left_or_right == LEFT:
neighbor = mod(my_id - 1, num_devices)
else:
neighbor = mod(my_id + 1, num_devices)
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
left_recv_sem,
left_send_sem,
copy_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, 2)
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, num_devices)
left_neighbor = mod(my_id - 1, num_devices)
left_copy_device = mod(my_id + outer_step + 1, num_devices)
right_copy_device = mod(my_id - outer_step - 1, num_devices)
left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
current_phase_slice = pl.ds(
phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
# We block here until our right neighbor tells use we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
# We block here until our left neighbor tells use we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
# --- Body ---
def inner_kernel(input_ref, accum_ref):
# We do not explicitly use += because we set should_accumulate_out=True.
accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(
inner_kernel,
in_specs=[inner_block_spec],
out_specs=inner_block_spec,
should_accumulate_out=True,
grid=inner_grid,
)
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_pipeline(
x_ref.at[left_copy_device, left_copy_slice],
hbm_scratch.at[working_slot, left_copy_slice],
)
@pl.when(phase == RIGHT)
def _():
accum_pipeline(
x_ref.at[right_copy_device, right_copy_slice],
hbm_scratch.at[working_slot, right_copy_slice],
)
# --- Epilogue ---
@pl.when(is_start)
def _():
initial_right_copy.wait()
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# Store result on last iteration.
@pl.when(last_iteration)
def _():
output_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=o_ref.at[current_phase_slice],
sem=copy_sem,
)
output_copy.start()
output_copy.wait()
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct(
(outer_block_size[0], outer_block_size[1]), jnp.float32
),
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, outer_block_size[0], outer_block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(
num_devices, outer_block_size[0], outer_block_size[1]
)
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode=dma_execution_mode, detect_races=detect_races),
compiler_params=pltpu.TPUCompilerParams(collective_id=19),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
if detect_races:
self.assertFalse(mosaic_interpret.races.races_found)
def test_race_detection(self):
num_devices = 4
mesh = jax.sharding.Mesh(np.array(jax.devices()[:4]), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x', None))
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)
def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem):
# Barrier with all devices before doing any DMAs.
barrier_sem = pltpu.get_barrier_semaphore()
@functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None)
def _(i, _):
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(jnp.int32(i),),
device_id_type=pltpu.DeviceIdType.MESH,
)
return None
pltpu.semaphore_wait(barrier_sem, num_devices)
# Send the specified DMAs.
my_id = lax.axis_index('x')
src_dst_ids = src_dst_ids_ref[:]
recv_count = 0
for i in range(src_dst_ids.shape[0]):
src_id = src_dst_ids[i, 0]
dst_id = src_dst_ids[i, 1]
@pl.when(src_id == my_id)
def _():
dma = pltpu.make_async_remote_copy(
src_ref=x_ref,
dst_ref=o_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(dst_id,),
device_id_type=pltpu.DeviceIdType.MESH,
)
dma.start()
dma.wait_send()
recv_count += jnp.where(dst_id == my_id, 1, 0)
# Wait until we have received all DMAs.
@pl.when(recv_count > 0)
def _():
fake_dma = pltpu.make_async_remote_copy(
src_ref=x_ref.at[pl.ds(0, 8 * recv_count)],
dst_ref=o_ref.at[pl.ds(0, 8 * recv_count)],
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(my_id,),
device_id_type=pltpu.DeviceIdType.MESH,
)
fake_dma.wait_recv()
@jax.jit
def run(src_dst_ids):
return shard_map.shard_map(
pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype),
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA],
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
interpret=mosaic_interpret.TPUInterpretParams(
dma_execution_mode='eager',
detect_races=True,
),
),
mesh=mesh,
in_specs=(P(None), P('x', None)),
out_specs=P('x', None),
check_rep=False,
)(src_dst_ids, input_arr)
run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready()
self.assertFalse(mosaic_interpret.races.races_found)
# Racing writes to device 2.
run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready()
self.assertTrue(mosaic_interpret.races.races_found)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())