Start a new TPU interpret mode for Pallas.

The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to Python functions that simulate these primitives.
When this interpret mode is run inside shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.
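
For example (a minimal sketch based on the tests added in this PR; the
add kernel and shapes here are illustrative only), a kernel is run in
this interpret mode by passing a TPUInterpretParams instance as the
`interpret` argument of pl.pallas_call:

  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl
  import jax._src.pallas.mosaic.interpret as mosaic_interpret

  def add_kernel(x_ref, y_ref, o_ref):
    # Loads and stores in the kernel are simulated on CPU via io_callbacks.
    o_ref[...] = x_ref[...] + y_ref[...]

  @jax.jit
  def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=mosaic_interpret.TPUInterpretParams(),
    )(x, y)

  x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
  print(add(x, x))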

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
Jacob Burnim 2024-11-22 10:49:17 -08:00
parent 5e915d3307
commit 1c82484c9b
7 changed files with 2040 additions and 6 deletions


@@ -652,6 +652,7 @@ pytype_strict_library(
"//jax/_src/pallas",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:helpers",
"//jax/_src/pallas/mosaic:interpret",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",


@@ -148,3 +148,15 @@ py_library(
"//jax/_src/pallas",
],
)
py_library(
name = "interpret",
srcs = ["interpret.py"],
deps = [
":core",
":primitives",
"//jax",
"//jax/_src/lib",
"//jax/_src/pallas",
] + py_deps("numpy"),
)


@@ -0,0 +1,909 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections.abc import Iterable, Sequence
import dataclasses
from functools import reduce
import math
import threading
from typing import Any
import jax
from jax import lax
from jax._src import callback
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.pallas.mosaic import primitives as mosaic_primitives
from jax._src.pallas.mosaic import core as mosaic_core
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src import pjit
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.util import (
safe_map,
safe_zip,
split_list
)
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
import numpy as np
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
ScratchShapeTree = pallas_core.ScratchShapeTree
CostEstimate = pallas_core.CostEstimate
@dataclasses.dataclass(frozen=True)
class TPUInterpretParams:
pass
class Semaphore:
def __init__(self):
self.cv = threading.Condition()
# TODO(jburnim): Make this an array.
self.counts = collections.defaultdict(int)
def signal(self, inc, device_id):
device_id = tuple(int(x) for x in device_id)
with self.cv:
self.counts[device_id] += inc
self.cv.notify_all()
def wait(self, value, device_id):
device_id = tuple(int(x) for x in device_id)
with self.cv:
while self.counts[device_id] < value:
self.cv.wait()
self.counts[device_id] -= value
@dataclasses.dataclass(frozen=True)
class SharedMemory:
# (memory_space, buffer_id, device_id) -> NumPy array
# TODO(jburnim): Handle Megacore.
mem: dict = dataclasses.field(default_factory=dict)
# semaphore_id -> Semaphore
sem: dict = dataclasses.field(default_factory=dict)
lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)
next_buffer_id: dict = dataclasses.field(
default_factory=lambda: collections.defaultdict(lambda: 100))
next_semaphore_id: dict = dataclasses.field(
default_factory=lambda: collections.defaultdict(lambda: 2000))
# TODO(jburnim): Do we want to support multiple instances of SharedMemory?
# Maybe for running multiple distinct interpreted computations in parallel?
_shared_memory = None
_shared_memory_init_lock = threading.Lock()
def _get_shared_memory() -> SharedMemory:
global _shared_memory
if _shared_memory is None:
with _shared_memory_init_lock:
if _shared_memory is None:
_shared_memory = SharedMemory()
return _shared_memory
def _clear_shared_memory():
global _shared_memory
with _shared_memory_init_lock:
_shared_memory = None
def _allocate_buffer(device_id, memory_space, val):
device_id = tuple(map(int, device_id))
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
val = np.array(val)
shared_memory = _get_shared_memory()
with shared_memory.lock:
buffer_id = shared_memory.next_buffer_id[device_id]
shared_memory.next_buffer_id[device_id] = buffer_id + 1
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
# TODO(jburnim): Raise an error if buffer_id is too big for int16.
return np.int16(buffer_id)
def _deallocate_buffer(device_id, memory_space, buffer_id):
device_id = tuple(map(int, device_id))
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
shared_memory = _get_shared_memory()
with shared_memory.lock:
# TODO(jburnim): Error if buffer doesn't exist?
shared_memory.mem.pop((memory_space, buffer_id, device_id), None)
def _allocate_semaphores(device_id, shape):
device_id = tuple(map(int, device_id))
shape = tuple(map(int, shape))
num_semaphores = math.prod(shape)
shared_memory = _get_shared_memory()
with shared_memory.lock:
semaphore_id = shared_memory.next_semaphore_id[device_id]
shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores
for i in range(semaphore_id, semaphore_id + num_semaphores):
      if i not in shared_memory.sem:
shared_memory.sem[i] = Semaphore()
# NOTE: For now, we use a relatively uncommon datatype (int16) for
# semaphore (and buffer) IDs, so these values are more easily identifiable
# in kernels.
#
# TODO(jburnim): Raise an error if any IDs are too big for int16.
return np.int16(
range(semaphore_id, semaphore_id + num_semaphores)
).reshape(shape)
TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = {
v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)}
TPU_MEMORY_SPACE_NAMES = {
i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)}
# Default to VMEM when no memory space is specified.
TPU_MEMORY_SPACE_IDXS[None] = (
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])
def get_barrier_semaphore(device_id, collective_id):
device_id = tuple(map(int, device_id))
collective_id = int(collective_id)
# TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
# with IDs for regular or DMA semaphores. (For example, store them in a
# different table.)
shared_memory = _get_shared_memory()
with shared_memory.lock:
semaphore_id = collective_id
    if semaphore_id not in shared_memory.sem:
shared_memory.sem[semaphore_id] = Semaphore()
return np.int16(semaphore_id)
def _transform_slice_or_index(slice_or_idx):
if isinstance(slice_or_idx, int):
return slice_or_idx
else:
start, size, stride = (
slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride)
return slice(start, start + size * stride, stride)
def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
ret = []
i = 0
j = 0
while True:
if i == len(slice_or_idx1):
ret.extend(slice_or_idx2[j:])
return tuple(ret)
elif j == len(slice_or_idx2):
ret.extend(slice_or_idx1[i:])
return tuple(ret)
elif isinstance(slice_or_idx1[i], int):
ret.append(slice_or_idx1[i])
i += 1
elif isinstance(slice_or_idx2[j], int):
ret.append(slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step)
i += 1
j += 1
else:
ret.append(slice(
slice_or_idx1[i].start + slice_or_idx2[j].start * slice_or_idx1[i].step,
slice_or_idx1[i].start + slice_or_idx2[j].stop * slice_or_idx1[i].step,
slice_or_idx1[i].step * slice_or_idx2[j].step
))
i += 1
j += 1
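# Worked example for _compose_slice_or_index: the transforms for
# ref[0:8][2:4] compose as
#   _compose_slice_or_index((slice(0, 8, 1),), (slice(2, 4, 1),))
#   == (slice(2, 4, 1),)
# i.e., the second indexer is mapped back into the coordinates of the
# underlying buffer.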
def _to_range(transforms) -> tuple[slice | int, ...]:
ret = ()
for transform in transforms:
# For now, assume only NDIndexer transforms.
ret = _compose_slice_or_index(
ret, tuple(_transform_slice_or_index(i) for i in transform.indices))
return ret
def get(device_id, memory_space, buffer_id, transforms):
device_id = tuple(int(x) for x in device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
transforms = jax.tree.map(int, transforms)
  except Exception:
raise ValueError('Advanced indexers are not supported on TPU')
shared_memory = _get_shared_memory()
with shared_memory.lock:
return shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
].copy()
def store(device_id, memory_space, buffer_id, transforms, val):
device_id = tuple(int(x) for x in device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
transforms = jax.tree.map(int, transforms)
  except Exception:
raise ValueError('Advanced indexers are not supported on TPU')
val = np.array(val)
shared_memory = _get_shared_memory()
with shared_memory.lock:
if transforms:
shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
] = val
else:
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
def swap(device_id, memory_space, buffer_id, transforms, val):
device_id = tuple(int(x) for x in device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
transforms = jax.tree.map(int, transforms)
  except Exception:
raise ValueError('Advanced indexers are not supported on TPU')
val = np.array(val)
shared_memory = _get_shared_memory()
with shared_memory.lock:
result = shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
].copy()
shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
] = val
return np.array(result)
def execute_dma(src, dst, send_sem, recv_sem):
# NOTE: `src` is a list of arguments for `get` (device_id, memory_space,
# buffer_id, transforms) and `dst` is a list of arguments for `store`
# (dst_device_id, dst_memory_space, dst_id, dst_transforms).
#
# TODO(jburnim): Clean this up.
# Do the read.
data = get(*src)
data_size = data.itemsize * data.size
# Signal the send semaphore.
if send_sem is not None:
send_sem.signal(data_size, device_id=src[0])
# Do the write.
store(*dst, data)
# Signal the receive semaphore.
recv_sem.signal(data_size, device_id=dst[0])
def print_memory(device_id):
device_id = tuple(map(int, device_id))
if all(d == 0 for d in device_id):
shared_memory = _get_shared_memory()
with shared_memory.lock:
print(shared_memory.mem)
def dma_start(device_id, src_memory_space, src_id, src_transforms,
dst_memory_space, dst_id, dst_transforms,
dst_sem,
src_sem,
dst_device_id):
device_id = tuple(int(x) for x in device_id)
src_memory_space, src_id = int(src_memory_space), int(src_id)
src_transforms = jax.tree.map(int, src_transforms)
dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
dst_transforms = jax.tree.map(int, dst_transforms)
dst_sem = int(dst_sem)
if src_sem is not None:
src_sem = int(src_sem)
if dst_device_id is not None:
dst_device_id = tuple(int(x) for x in dst_device_id)
else:
dst_device_id = device_id
shared_memory = _get_shared_memory()
with shared_memory.lock:
dst_sem = shared_memory.sem[dst_sem]
if src_sem is not None:
src_sem = shared_memory.sem[src_sem]
# For now, just execute the DMA immediately.
# TODO(jburnim): Execute DMAs asynchronously.
execute_dma(
(device_id, src_memory_space, src_id, src_transforms),
(dst_device_id, dst_memory_space, dst_id, dst_transforms),
src_sem,
dst_sem)
def dma_wait(device_id, sem, size):
device_id = tuple(int(x) for x in device_id)
sem = int(sem)
size = int(size)
shared_memory = _get_shared_memory()
with shared_memory.lock:
sem = shared_memory.sem[sem]
sem.wait(size, device_id)
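# Example of the byte accounting shared by execute_dma and dma_wait: a DMA
# of an (8, 128) float32 block signals both the send and receive semaphores
# by 8 * 128 * 4 = 4096 bytes, and the matching dma_wait waits for 4096.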
def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
device_id = tuple(map(int, device_id))
sem = int(sem)
inc = int(inc)
target_device_id = tuple(map(int, target_device_id))
if target_core_index is not None:
raise NotImplementedError()
shared_memory = _get_shared_memory()
with shared_memory.lock:
sem = shared_memory.sem[sem]
sem.signal(inc, target_device_id)
def semaphore_wait(device_id, sem, value):
device_id = tuple(map(int, device_id))
sem = int(sem)
value = int(value)
shared_memory = _get_shared_memory()
with shared_memory.lock:
sem = shared_memory.sem[sem]
sem.wait(value, device_id)
def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
for transform in transforms:
if transform is None:
continue
shape = transform.transform_shape(shape)
dtype = transform.transform_dtype(dtype)
return shape, dtype
@lu.cache
def _to_jaxpr(flat_fun, in_avals):
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr
def _is_any(memory_space):
return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or
(memory_space == pallas_core.MemorySpace.ANY))
def _interpret_jaxpr(jaxpr, *args, compiler_params):
env = {}
def read(var):
if isinstance(var, jax_core.Literal):
return var.val
else:
return env[var]
def write(var, value):
env[var] = value
jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
# Get the mesh coordinates.
device_coords = tuple(
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
# TODO(jburnim): Convert to a single integer device ID.
# TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
# it for each sub-jaxpr.
# TODO(jburnim): Clean up and finish this evaluation loop. For example:
# - Handle missing Pallas primitives, like masked_load.
# - Replace the big if-statement with a dictionary of rules.
# - Handle other higher-order primitives?
# - Megacore.
for eqn in jaxpr.eqns:
prim = eqn.primitive
invals = jax.util.safe_map(read, eqn.invars)
if prim is primitives.load_p:
raise NotImplementedError()
elif prim is primitives.swap_p:
raise NotImplementedError()
elif prim is lax.cond_p:
def _make_branch(jaxpr):
return lambda *args: _interpret_jaxpr(
jaxpr, *args, compiler_params=compiler_params)
out = lax.switch(
invals[0],
[_make_branch(branch_jaxpr.jaxpr)
for branch_jaxpr in eqn.params['branches']],
*invals[1:])
elif prim is lax.scan_p:
consts, init_carry, xs = split_list(
invals, [eqn.params['num_consts'], eqn.params['num_carry']])
def _scan_body(c, a):
return split_list(
_interpret_jaxpr(eqn.params['jaxpr'].jaxpr, *consts, *c, *a,
compiler_params=compiler_params),
[eqn.params['num_carry']])
carry, out = lax.scan(_scan_body, init_carry, xs=xs,
length=eqn.params.get('length', None))
out = carry + out
elif prim is lax.while_p:
cond_consts, body_consts, init_vals = split_list(
invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']])
out = lax.while_loop(
lambda args: _interpret_jaxpr(eqn.params['cond_jaxpr'].jaxpr,
*cond_consts, *args,
compiler_params=compiler_params)[0],
lambda args: _interpret_jaxpr(eqn.params['body_jaxpr'].jaxpr,
*body_consts, *args,
compiler_params=compiler_params),
init_vals)
elif prim is pjit.pjit_p:
pjit_jaxpr = eqn.params['jaxpr']
def f(*args):
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
compiler_params=compiler_params)
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals)
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
elif prim is primitives.run_scoped_p:
# Allocate a buffer or semaphore for each element of
# eqn.params['jaxpr'].invars .
allocs = []
for v in eqn.params['jaxpr'].invars:
if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
allocs.append(callback.io_callback(
_allocate_semaphores,
jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
device_coords,
v.aval.shape,
ordered=True))
else:
allocs.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
ordered=True))
out = _interpret_jaxpr(eqn.params['jaxpr'], *invals, *allocs,
compiler_params=compiler_params)
      # Deallocate the buffers we allocated above.  (Semaphores are not yet
      # deallocated -- see the TODO below.)
      for a, v in zip(allocs, eqn.params['jaxpr'].invars):
        if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
          # TODO(jburnim): Delete semaphores.
          # callback.io_callback(
          #     _deallocate_semaphores,
          #     None,
          #     device_coords,
          #     a,
          #     ordered=True)
          pass
        else:
          callback.io_callback(
              _deallocate_buffer,
              None,
              device_coords,
              TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
              a,
              ordered=True)
elif prim is state_primitives.get_p:
out = callback.io_callback(
get,
eqn.outvars[0].aval,
device_coords,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
invals[0],
jax.tree.unflatten(eqn.params['tree'], invals[1:]),
ordered=True)
elif prim is state_primitives.swap_p:
out = callback.io_callback(
swap,
eqn.outvars[0].aval,
device_coords,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
invals[0],
jax.tree.unflatten(eqn.params['tree'], invals[2:]),
invals[1],
ordered=True)
elif prim is mosaic_primitives.dma_start_p:
(src, src_transforms,
dst, dst_transforms,
dst_sem, dst_sem_transforms,
src_sem, src_sem_transforms,
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
(orig_src_ref, _, orig_dst_ref, *_
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
callback.io_callback(
dma_start,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space],
src, src_transforms,
TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space],
dst, dst_transforms,
state_discharge.transform_array(dst_sem, dst_sem_transforms),
state_discharge.transform_array(src_sem, src_sem_transforms),
device_id,
ordered=True)
out = []
elif prim is mosaic_primitives.dma_wait_p:
(src, src_transforms,
dst, dst_transforms,
dst_sem, dst_sem_transforms,
src_sem, src_sem_transforms,
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
read_shape, read_dtype = _compute_transformed_shape_and_dtype(
eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
callback.io_callback(
dma_wait,
(),
device_coords,
state_discharge.transform_array(dst_sem, dst_sem_transforms),
math.prod(read_shape) * read_dtype.itemsize,
ordered=True)
out = []
elif prim is mosaic_primitives.get_barrier_semaphore_p:
out = callback.io_callback(
get_barrier_semaphore,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
compiler_params['mosaic']['collective_id'],
ordered=True)
elif prim is mosaic_primitives.semaphore_signal_p:
sem, sem_transforms, inc, device_id, core_index = (
jax.tree.unflatten(eqn.params['args_tree'], invals))
callback.io_callback(
semaphore_signal,
(),
device_coords,
state_discharge.transform_array(sem, sem_transforms),
inc,
device_id,
core_index,
ordered=True)
out = []
elif prim is mosaic_primitives.semaphore_wait_p:
sem, sem_transforms, value = (
jax.tree.unflatten(eqn.params['args_tree'], invals))
callback.io_callback(
semaphore_wait,
(),
device_coords,
state_discharge.transform_array(sem, sem_transforms),
value,
ordered=True)
out = []
else:
out = prim.bind(*invals, **eqn.params)
out = out if prim.multiple_results else [out]
jax.util.safe_map(write, eqn.outvars, out)
return jax.util.safe_map(read, jaxpr.outvars)
def _initialize_output_vals(
block_mappings_output: Iterable[BlockMapping],
input_args, input_output_aliases) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases}
output_vals = []
for i, bm in enumerate(block_mappings_output):
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(primitives.uninitialized_value(
bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype))
return output_vals
def _compute_start_indices(block_mapping, loop_idx, *args):
block_indices = (
jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args))
if isinstance(block_mapping.indexing_mode, pallas_core.Blocked):
ret = tuple(i if b is pallas_core.mapped else b * i
for b, i in zip(block_mapping.block_shape, block_indices))
elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked):
ret = block_indices
else:
raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}")
return ret
def _get_next_indices(grid, indices):
next_indices = []
carry = True
for dim_size, index in reversed(list(zip(grid, indices))):
i = jnp.where(carry, index + 1, index)
carry = dim_size == i
next_indices.append(jnp.where(carry, 0, i))
return tuple(reversed(next_indices))
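# Example: for grid (2, 3), _get_next_indices steps through program ids with
# the last grid axis fastest -- (0, 0) -> (0, 1) -> (0, 2) -> (1, 0) -- and
# (1, 2) wraps back around to (0, 0).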
def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx)
output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
dtype=np.bool_)])
return lax.squeeze(output, squeeze_dims)
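# Example: with block_shape (1, 128) and is_indexing (True, False), the
# dynamic slice of shape (1, 128) is squeezed to (128,), matching a block
# spec that indexes (rather than slices) the first dimension.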
def get_interpret_effects():
return {callback._OrderedIOEffect}
def interpret_pallas_call(
*args,
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndSrcInfo,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
compiler_params: Any,
cost_estimate: CostEstimate,
out_avals: tuple[jax_core.AbstractValue, ...],
):
del debug, cost_estimate, out_avals
# args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
dynamic_grid_args, scalars, input_args = split_list( # type: ignore
args,
[grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands],
)
# TODO(jburnim): Support dynamic grid sizes?
grid = grid_mapping.static_grid
device_coords = tuple(
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
# Allocate buffers in HBM for outputs.
io_alias_map = dict(input_output_aliases)
output_buffer_ids = []
output_vals = _initialize_output_vals(
grid_mapping.block_mappings_output, args, input_output_aliases)
for out_val in output_vals:
output_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
out_val,
ordered=True))
# Allocate buffers for all kernel arguments (e.g., scalars, inputs,
# outputs, scratch).
kernel_buffer_ids = []
for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
kernel_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
val,
ordered=True))
for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
output_idx = i - grid_mapping.num_inputs
is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
kernel_buffer_ids.append(callback.io_callback(
_allocate_semaphores,
jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
device_coords,
var.aval.shape,
ordered=True))
elif is_output and _is_any(var.aval.memory_space):
# Don't allocate a buffer -- use the already-allocated HBM output buffer.
#
# TODO(jburnim): For kernel args in HBM, check that block shape is the
# same as for the corresponding pallas_call input, and that the index_map
# is trivial.
kernel_buffer_ids.append(output_buffer_ids[output_idx])
elif (i < grid_mapping.num_inputs) and (i in io_alias_map):
# Instead of allocating a new buffer, use the already-allocated
# HBM output buffer.
assert _is_any(var.aval.memory_space)
kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
else:
kernel_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
ordered=True))
num_inputs = grid_mapping.num_inputs
_, input_ids, kernel_output_ids, _ = split_list(
kernel_buffer_ids,
[grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
input_vars, output_vars = split_list(
jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs])
# For kernel inputs that are in HBM, we populate the buffer once before
# any kernel invocations.
for buffer_id, var, val in zip(input_ids, input_vars, input_args):
if not _is_any(var.aval.memory_space):
continue
    if val.shape != var.aval.shape:
      # TODO(jburnim): Also check that the index_map is trivial.
      raise ValueError(
          f'Shape mismatch for input in HBM: expected {var.aval.shape}, '
          f'got {val.shape}.')
callback.io_callback(
store,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
buffer_id,
(),
val,
ordered=True)
is_indexing_dim = [
tuple(b is pallas_core.mapped for b in bm.block_shape)
for bm in grid_mapping.block_mappings
]
block_shapes = [
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
]
scalar_ids, in_out_ids, scratch_ids = split_list(
kernel_buffer_ids,
[grid_mapping.num_index_operands, len(grid_mapping.block_mappings)])
if grid:
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
else:
# Base case is always one iteration when grid is ()
num_iterations = 1
def body(carry):
    # The loop carry: (i, loop_idx) --
    #  - i: int32 is the iteration index
    #  - loop_idx: tuple[int32] are the program ids for each grid axis
i, loop_idx = carry
if grid_mapping.local_grid_env is not None:
local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
else:
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.vmapped_dims
)
with pallas_core.grid_env(local_grid_env):
# Copy slices of the input to the kernel buffers.
#
# TODO(jburnim): Only copy slices when the index mapping has changed?
start_indices = [_compute_start_indices(bm, loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
for j, var in enumerate(input_vars):
if _is_any(var.aval.memory_space):
continue
sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j],
input_args[j], is_indexing_dim[j])
        assert sliced_val.shape == var.aval.shape
callback.io_callback(
store,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
input_ids[j],
(),
sliced_val,
ordered=True)
# Invoke the kernel.
_interpret_jaxpr(jaxpr, *kernel_buffer_ids,
compiler_params=compiler_params)
# Copy from the kernel buffers to slices of the output in HBM.
#
# TODO(jburnim): Only copy if the index mapping will change in the
# next iteration (or if this is the last iteration)?
for j, var in enumerate(output_vars):
if _is_any(var.aval.memory_space):
continue
kernel_output_val = callback.io_callback(
get,
var.aval,
device_coords,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
kernel_output_ids[j],
(),
ordered=True)
transform = indexing.NDIndexer(
indices=tuple(indexing.ds(st, sz)
for st, sz in zip(start_indices[num_inputs + j],
block_shapes[num_inputs + j])),
shape=output_vals[j].shape,
int_indexer_shape=())
callback.io_callback(
store,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
output_buffer_ids[j],
(transform,),
kernel_output_val,
ordered=True)
return i + 1, _get_next_indices(grid, loop_idx)
# TODO(jburnim): Handle parallel grid dimensions + megacore.
_ = lax.while_loop(
lambda carry: carry[0] < num_iterations,
body,
(jnp.int32(0), (jnp.int32(0),) * len(grid))
)
# Read the output from the allocated output buffers.
ret = [
callback.io_callback(
get,
val,
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
output_buffer_id,
(),
ordered=True)
for val, output_buffer_id in zip(output_vals, output_buffer_ids)
]
for buffer_id in output_buffer_ids:
callback.io_callback(
_deallocate_buffer,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
buffer_id,
ordered=True)
for buffer_id, var in zip(kernel_buffer_ids, jaxpr.invars):
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
pass
else:
callback.io_callback(
_deallocate_buffer,
(),
device_coords,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
buffer_id,
ordered=True)
return ret


@@ -17,7 +17,9 @@ from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import enum
from functools import partial, reduce
import types
from typing import Any, Literal
import jax
@@ -82,19 +84,31 @@ pallas_call_p.def_impl(_pallas_call_impl)
 def _pallas_call_abstract_eval(
-    *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
+    *avals,
+    out_avals: tuple[jax_core.AbstractValue, ...],
+    interpret,
+    backend,
+    **params
 ):
   del avals
+  if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
+    # Report effects that will be introduced when running/lowering
+    # mosaic_tpu_interpret.interpret_pallas_call .
+    effs = mosaic_tpu_interpret.get_interpret_effects()
+  else:
+    effs = jax_core.no_effects
   # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
   return [
       jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
       if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
       else a
       for a in out_avals
-  ]
+  ], effs
-pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
+pallas_call_p.def_effectful_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(
@@ -1230,9 +1244,12 @@ def _pallas_call_lowering(
   if params['jaxpr'].constvars:
     raise ValueError('Cannot lower a pallas_call with constants.')
   if interpret:
-    impl = partial(hlo_interpreter.pallas_call_hlo_interpret,
-                   backend=backend,
-                   **params)
+    if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
+      impl = partial(mosaic_tpu_interpret.interpret_pallas_call, **params)
+    else:
+      impl = partial(hlo_interpreter.pallas_call_hlo_interpret,
+                     backend=backend,
+                     **params)
   return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
def cpu_lowering(ctx: mlir.LoweringRuleContext,
@@ -1681,3 +1698,10 @@ try:
   from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend
 except ImportError:
   mosaic_tpu_backend = None  # type: ignore
+
+try:
+  from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret
+except ImportError:
+  mosaic_tpu_interpret = types.SimpleNamespace(  # type: ignore
+      TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)),
+  )


@@ -426,6 +426,32 @@ jax_multiplatform_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
name = "tpu_pallas_interpret_test",
srcs = [
"tpu_pallas_interpret_test.py",
],
disable_configs = ["cpu_shardy"],
enable_backends = ["cpu"],
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
name = "tpu_pallas_interpret_distributed_test",
srcs = [
"tpu_pallas_interpret_distributed_test.py",
],
disable_configs = ["cpu_shardy"],
enable_backends = ["cpu"],
deps = [
"//third_party/py/jax:pallas",
"//third_party/py/jax:pallas_tpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
name = "tpu_paged_attention_kernel_test",
srcs = ["tpu_paged_attention_kernel_test.py"],


@@ -0,0 +1,995 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TPU-specific interpret mode.
To work around https://github.com/jax-ml/jax/issues/25671 , this file
contains only tests that use shard_map.
"""
from absl.testing import absltest
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
import jax._src.pallas.mosaic.interpret as mosaic_interpret
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
P = jax.sharding.PartitionSpec
class InterpretDistributedTest(jtu.JaxTestCase):
def test_right_permute_example(self):
num_devices = jax.device_count()
if num_devices < 4:
self.skipTest(f'requires at least 4 devices, found {num_devices}')
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(
jax.random.key(0), (8, 128 * num_devices), dtype=jnp.float32)
input_arr = jax.device_put(input_arr, sharding)
def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
my_id = lax.axis_index('x')
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
barrier_sem = pltpu.get_barrier_semaphore()
def _body(ijk):
i, (j, k) = ijk
lax.cond(
(i == 0) | (j == 0),
lambda: pltpu.semaphore_signal(
barrier_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH),
lambda: pltpu.semaphore_signal(
barrier_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH))
return (i + 1, (j + 1, k + 1))
lax.while_loop(lambda ijk: ijk[0] < 2, _body, (0, (0, 0)))
pltpu.semaphore_wait(barrier_sem, 2)
def _body2(i, a):
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
return i + 1, a + 1
_ = lax.scan(_body2, 0, jnp.arange(4.0), unroll=2)
out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# We allocate DMA semaphores in scratch memory.
[pltpu.SemaphoreType.DMA] * 2
),
)
right_permute = pl.pallas_call(
right_permute_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=pltpu.TPUCompilerParams(collective_id=13),
interpret=mosaic_interpret.TPUInterpretParams(),
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
right_permute,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.ppermute(x, 'x', perm),
mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result)
def test_all_gather_example(self):
num_devices = jax.device_count()
if num_devices < 4:
self.skipTest(f'requires at least 4 devices, found {num_devices}')
partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)
def all_gather_kernel(input_ref,
output_ref,
local_copy_sem,
send_sem,
recv_sems):
outer_step = pl.program_id(0)
my_id = lax.axis_index('x')
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
copy_slot = my_id - outer_step
copy_slot = lax.rem(copy_slot + num_devices, jnp.int32(num_devices))
@pl.when(outer_step == 0)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
local_copy_op = pltpu.make_async_copy(
src_ref=input_ref,
dst_ref=output_ref.at[my_id],
sem=local_copy_sem,
)
local_copy_op.start()
local_copy_op.wait()
# Copy to our right neighbor.
# Note that we will also be receiving data from our left neighbor,
# but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
# that the indices do not need to be symmetric between remote DMAs.
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=output_ref.at[copy_slot],
dst_ref=output_ref.at[copy_slot],
send_sem=send_sem,
recv_sem=recv_sems.at[outer_step],
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# DMA semaphores are allocated in scratch memory.
# We allocated one semaphore for a local HBM-VMEM copy,
# and one for the remote send semaphore.
[pltpu.SemaphoreType.DMA] * 2
# We additionally allocate one receive semaphore per device.
# This is to avoid situations where we have multiple
# DMAs in flight, as we do not want to share a receive
# semaphore between the DMAs.
+ [pltpu.SemaphoreType.DMA((num_devices-1,))]
),
grid=(num_devices-1,)
)
all_gather = pl.pallas_call(
all_gather_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
all_gather,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.all_gather(x, 'x'),
mesh=mesh, in_specs=partition, out_specs=partition
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result)
def test_all_reduce_sum_example(self):
num_devices = jax.device_count()
if num_devices < 4:
self.skipTest(f'requires at least 4 devices, found {num_devices}')
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
input_arr = jax.random.uniform(
jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)
def all_reduce_kernel(
x_ref,
o_ref,
hbm_scratch,
copy_sem,
remote_recv_sem,
remote_send_sem,
capacity_sem,
receive_scratch,
):
outer_step = pl.program_id(0)
working_slot = lax.rem(outer_step, jnp.int32(2))
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
left_neighbor = lax.rem(my_id - 1 + num_devices, jnp.int32(num_devices))
@pl.when(outer_step == 0)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
# Initialize o_ref, acc_scratch, and hbm_scratch.
o_ref[...] = jnp.zeros_like(o_ref)
receive_scratch[...] = jnp.zeros_like(receive_scratch)
initial_copy = pltpu.make_async_remote_copy(
src_ref=x_ref,
dst_ref=hbm_scratch.at[working_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_copy.start()
initial_copy.wait()
# Signal to our left neighbor that we are ready to receive.
# Without this signal, our left neighbor can be >=1 iteration ahead,
# meaning it could write into our working slot.
pltpu.semaphore_signal(
capacity_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# Copy the partial result our left neighbor sent to us into VMEM for
# computation.
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=receive_scratch,
sem=copy_sem,
)
local_copy.start()
# Block until our right neighbor is ready to receive.
pltpu.semaphore_wait(capacity_sem, 1)
# Pass the value to our right neighbor.
remote_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=hbm_scratch.at[receiving_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy.start()
# Finish local copy and accumulate while remote_copy is happening.
local_copy.wait()
o_ref[...] += receive_scratch[...]
# Block until remote copy finishes.
remote_copy.wait()
out_shape = (
jax.ShapeDtypeStruct((8, 128), jnp.float32),
# We allocate the double-buffer as a Pallas output so that it is
# resident in HBM.
jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# Our input lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
# Our output lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
# Our double-buffer lives in HBM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices,),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 3
+ [pltpu.SemaphoreType.REGULAR] # capacity_sem
+ [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch
),
)
kernel = pl.pallas_call(
all_reduce_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
pallas_result = jax.jit(
shard_map.shard_map(
kernel,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]
def lax_sum(x):
return lax.psum(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
def test_reduce_scatter_sum_example(self):
num_devices = jax.device_count()
if num_devices < 4:
self.skipTest(f'requires at least 4 devices, found {num_devices}')
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(block_size[0] * num_devices, block_size[1] * num_devices),
dtype=jnp.float32,
)
input_arr = jax.device_put(input_arr, sharding)
LEFT = 0
RIGHT = 1
def mod(x, n):
return lax.rem(x + n, n)
def signal(left_or_right, semaphore):
my_id = lax.axis_index('x')
if left_or_right == LEFT:
neighbor = mod(my_id - 1, jnp.int32(num_devices))
else:
neighbor = mod(my_id + 1, jnp.int32(num_devices))
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
local_copy_sem,
left_recv_sem,
left_send_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
accum_scratch,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, jnp.int32(2))
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, jnp.int32(num_devices))
left_neighbor = mod(my_id - 1, jnp.int32(num_devices))
left_copy_device = mod(my_id + outer_step + 1, jnp.int32(num_devices))
right_copy_device = mod(my_id - outer_step - 1, jnp.int32(num_devices))
# Slices can be specified using pl.ds(start, size)
left_copy_slice = pl.ds(0, block_size[0] // 2)
right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
          # Note: Right copy is flipped with respect to slots since we are copying
# to the next outer_step iteration.
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
o_ref[...] = jnp.zeros_like(o_ref[...])
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
# --- Body ---
# At the beginning of our kernel body, we start a DMA which copies
# the result we computed in the previous phase to our neighbor.
# This allows us to overlap the communication of sending our previous phase
# with the computation for the current phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
          # We block here until our right neighbor tells us we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
          # We block here until our left neighbor tells us we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=accum_scratch,
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
@pl.when(phase == RIGHT)
def _():
accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
src_ref=accum_scratch,
dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(is_start)
def _():
initial_right_copy.wait()
# At the end of our kernel body, we wait on the DMA of the previous phase
# to make sure the results are ready for the next phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# --- Epilogue ---
# Store result on last iteration.
@pl.when(last_iteration)
def _():
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
o_ref[left_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
o_ref[right_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, block_size[0], block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
+ [
pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
] # accum_scratch
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(),
compiler_params=pltpu.TPUCompilerParams(collective_id=7),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, block_size[0], block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
def test_reduce_scatter_sum_with_emit_pipeline_example(self):
self.skipTest('requires a patched pallas.emit_pipeline to specify/fake '
'the TPU generation')
if jax.config.jax_enable_x64:
self.skipTest('pallas.emit_pipeline + x64 is not currently supported')
num_devices = jax.device_count()
if num_devices < 4:
self.skipTest(f'requires at least 4 devices, found {num_devices}')
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
    # We pick a large outer kernel block size that we do not want to place
    # in VMEM. For this test we use (512, 512), although in principle this
    # can be much larger.
    outer_block_size = (512, 512)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(
outer_block_size[0] * num_devices,
outer_block_size[1] * num_devices,
),
)
input_arr = jax.device_put(input_arr, sharding)
inner_grid = (
outer_block_size[0] // inner_block_size[0] // 2,
outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
index_map=lambda i, j: (i, j),
block_shape=inner_block_size,
memory_space=pltpu.TPUMemorySpace.ANY,
)
LEFT = 0
RIGHT = 1
def mod(x, n):
return lax.rem(x + n, n)
def signal(left_or_right, semaphore):
my_id = lax.axis_index('x')
if left_or_right == LEFT:
neighbor = mod(my_id - 1, num_devices)
else:
neighbor = mod(my_id + 1, num_devices)
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
left_recv_sem,
left_send_sem,
copy_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, 2)
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, num_devices)
left_neighbor = mod(my_id - 1, num_devices)
left_copy_device = mod(my_id + outer_step + 1, num_devices)
right_copy_device = mod(my_id - outer_step - 1, num_devices)
left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
current_phase_slice = pl.ds(
phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
          # We block here until our right neighbor tells us we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
          # We block here until our left neighbor tells us we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
# --- Body ---
def inner_kernel(input_ref, accum_ref):
# We do not explicitly use += because we set should_accumulate_out=True.
accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(
inner_kernel,
in_specs=[inner_block_spec],
out_specs=inner_block_spec,
should_accumulate_out=True,
grid=inner_grid,
)
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_pipeline(
x_ref.at[left_copy_device, left_copy_slice],
hbm_scratch.at[working_slot, left_copy_slice],
)
@pl.when(phase == RIGHT)
def _():
accum_pipeline(
x_ref.at[right_copy_device, right_copy_slice],
hbm_scratch.at[working_slot, right_copy_slice],
)
# --- Epilogue ---
@pl.when(is_start)
def _():
initial_right_copy.wait()
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# Store result on last iteration.
@pl.when(last_iteration)
def _():
output_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=o_ref.at[current_phase_slice],
sem=copy_sem,
)
output_copy.start()
output_copy.wait()
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct(
(outer_block_size[0], outer_block_size[1]), jnp.float32
),
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, outer_block_size[0], outer_block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(
num_devices, outer_block_size[0], outer_block_size[1]
)
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
interpret=mosaic_interpret.TPUInterpretParams(),
compiler_params=pltpu.TPUCompilerParams(collective_id=19),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())


@@ -0,0 +1,67 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TPU-specific interpret mode.
To work around https://github.com/jax-ml/jax/issues/25671 , this file
contains only tests that do not use shard_map.
"""
from absl.testing import absltest
import numpy as np
import jax
from jax._src import test_util as jtu
import jax._src.pallas.mosaic.interpret as mosaic_interpret
from jax.experimental import pallas as pl
jax.config.parse_flags_with_absl()
class InterpretTest(jtu.JaxTestCase):
def test_matmul_example(self):
num_devices = jax.device_count()
if num_devices > 1:
# Workaround for https://github.com/jax-ml/jax/issues/25671
self.skipTest(f'requires 1 device, found {num_devices}')
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
@jax.jit
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
),
interpret=mosaic_interpret.TPUInterpretParams(),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y, atol=1e-4)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())