mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU, while simulating a TPU's shared memory, multiple devices/cores, remote DMAs, and synchronization. The basic approach is to execute the kernel's Jaxpr on CPU, but to replace all load/store, DMA, and synchronization primitives with io_callbacks to a Python functions that simulate these primitives. When this interpret mode is run inside of shard_map and jit, the shards will run in parallel, simulating the parallel execution of the kernel on multiple TPU devices. The initial version in this PR can successfully interpret the examples in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html , but is still missing a lot of functionality, including: - Executing DMAs asynchronously. - Padding in pallas_call. - Propagating source info.
This commit is contained in:
parent
5e915d3307
commit
1c82484c9b
@ -652,6 +652,7 @@ pytype_strict_library(
|
||||
"//jax/_src/pallas",
|
||||
"//jax/_src/pallas/mosaic:core",
|
||||
"//jax/_src/pallas/mosaic:helpers",
|
||||
"//jax/_src/pallas/mosaic:interpret",
|
||||
"//jax/_src/pallas/mosaic:lowering",
|
||||
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/mosaic:pipeline",
|
||||
|
@ -148,3 +148,15 @@ py_library(
|
||||
"//jax/_src/pallas",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "interpret",
|
||||
srcs = ["interpret.py"],
|
||||
deps = [
|
||||
":core",
|
||||
":primitives",
|
||||
"//jax",
|
||||
"//jax/_src/lib",
|
||||
"//jax/_src/pallas",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
909
jax/_src/pallas/mosaic/interpret.py
Normal file
909
jax/_src/pallas/mosaic/interpret.py
Normal file
@ -0,0 +1,909 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
from collections.abc import Iterable, Sequence
|
||||
import dataclasses
|
||||
from functools import reduce
|
||||
import math
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import callback
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.pallas.mosaic import primitives as mosaic_primitives
|
||||
from jax._src.pallas.mosaic import core as mosaic_core
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src import pjit
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import indexing
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.util import (
|
||||
safe_map,
|
||||
safe_zip,
|
||||
split_list
|
||||
)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Grid = pallas_core.Grid
|
||||
TupleGrid = pallas_core.TupleGrid
|
||||
GridSpec = pallas_core.GridSpec
|
||||
BlockMapping = pallas_core.BlockMapping
|
||||
GridMapping = pallas_core.GridMapping
|
||||
BlockSpec = pallas_core.BlockSpec
|
||||
BlockSpecTree = pallas_core.BlockSpecTree
|
||||
NoBlockSpec = pallas_core.NoBlockSpec
|
||||
no_block_spec = pallas_core.no_block_spec
|
||||
ScratchShapeTree = pallas_core.ScratchShapeTree
|
||||
CostEstimate = pallas_core.CostEstimate
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TPUInterpretParams:
|
||||
pass
|
||||
|
||||
|
||||
class Semaphore:
|
||||
def __init__(self):
|
||||
self.cv = threading.Condition()
|
||||
|
||||
# TODO(jburnim): Make this an array.
|
||||
self.counts = collections.defaultdict(int)
|
||||
|
||||
def signal(self, inc, device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
with self.cv:
|
||||
self.counts[device_id] += inc
|
||||
self.cv.notify_all()
|
||||
|
||||
def wait(self, value, device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
with self.cv:
|
||||
while self.counts[device_id] < value:
|
||||
self.cv.wait()
|
||||
self.counts[device_id] -= value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SharedMemory:
|
||||
# (memory_space, buffer_id, device_id) -> NumPy array
|
||||
# TODO(jburnim): Handle Megacore.
|
||||
mem: dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
# semaphore_id -> Semaphore
|
||||
sem: dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)
|
||||
|
||||
next_buffer_id: dict = dataclasses.field(
|
||||
default_factory=lambda: collections.defaultdict(lambda: 100))
|
||||
next_semaphore_id: dict = dataclasses.field(
|
||||
default_factory=lambda: collections.defaultdict(lambda: 2000))
|
||||
|
||||
# TODO(jburnim): Do we want to support multiple instances of SharedMemory?
|
||||
# Maybe for running multiple distinct interpreted computations in parallel?
|
||||
_shared_memory = None
|
||||
_shared_memory_init_lock = threading.Lock()
|
||||
|
||||
def _get_shared_memory() -> SharedMemory:
|
||||
global _shared_memory
|
||||
if _shared_memory is None:
|
||||
with _shared_memory_init_lock:
|
||||
if _shared_memory is None:
|
||||
_shared_memory = SharedMemory()
|
||||
return _shared_memory
|
||||
|
||||
def _clear_shared_memory():
|
||||
global _shared_memory
|
||||
with _shared_memory_init_lock:
|
||||
_shared_memory = None
|
||||
|
||||
def _allocate_buffer(device_id, memory_space, val):
|
||||
device_id = tuple(map(int, device_id))
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
val = np.array(val)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
buffer_id = shared_memory.next_buffer_id[device_id]
|
||||
shared_memory.next_buffer_id[device_id] = buffer_id + 1
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
|
||||
|
||||
# TODO(jburnim): Raise an error if buffer_id is too big for int16.
|
||||
return np.int16(buffer_id)
|
||||
|
||||
def _deallocate_buffer(device_id, memory_space, buffer_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
# TODO(jburnim): Error if buffer doesn't exist?
|
||||
shared_memory.mem.pop((memory_space, buffer_id, device_id), None)
|
||||
|
||||
def _allocate_semaphores(device_id, shape):
|
||||
device_id = tuple(map(int, device_id))
|
||||
shape = tuple(map(int, shape))
|
||||
num_semaphores = math.prod(shape)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
semaphore_id = shared_memory.next_semaphore_id[device_id]
|
||||
shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores
|
||||
for i in range(semaphore_id, semaphore_id + num_semaphores):
|
||||
if not i in shared_memory.sem:
|
||||
shared_memory.sem[i] = Semaphore()
|
||||
|
||||
# NOTE: For now, we use a relatively uncommon datatype (int16) for
|
||||
# semaphore (and buffer) IDs, so these values are more easily identifiable
|
||||
# in kernels.
|
||||
#
|
||||
# TODO(jburnim): Raise an error if any IDs are too big for int16.
|
||||
return np.int16(
|
||||
range(semaphore_id, semaphore_id + num_semaphores)
|
||||
).reshape(shape)
|
||||
|
||||
|
||||
TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = {
|
||||
v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)}
|
||||
TPU_MEMORY_SPACE_NAMES = {
|
||||
i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)}
|
||||
|
||||
# Default to VMEM when no memory space is specified.
|
||||
TPU_MEMORY_SPACE_IDXS[None] = (
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])
|
||||
|
||||
def get_barrier_semaphore(device_id, collective_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
collective_id = int(collective_id)
|
||||
|
||||
# TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
|
||||
# with IDs for regular or DMA semaphores. (For example, store them in a
|
||||
# different table.)
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
semaphore_id = collective_id
|
||||
if not semaphore_id in shared_memory.sem:
|
||||
shared_memory.sem[semaphore_id] = Semaphore()
|
||||
|
||||
return np.int16(semaphore_id)
|
||||
|
||||
def _transform_slice_or_index(slice_or_idx):
|
||||
if isinstance(slice_or_idx, int):
|
||||
return slice_or_idx
|
||||
else:
|
||||
start, size, stride = (
|
||||
slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride)
|
||||
return slice(start, start + size * stride, stride)
|
||||
|
||||
def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
|
||||
ret = []
|
||||
i = 0
|
||||
j = 0
|
||||
while True:
|
||||
if i == len(slice_or_idx1):
|
||||
ret.extend(slice_or_idx2[j:])
|
||||
return tuple(ret)
|
||||
elif j == len(slice_or_idx2):
|
||||
ret.extend(slice_or_idx1[i:])
|
||||
return tuple(ret)
|
||||
elif isinstance(slice_or_idx1[i], int):
|
||||
ret.append(slice_or_idx1[i])
|
||||
i += 1
|
||||
elif isinstance(slice_or_idx2[j], int):
|
||||
ret.append(slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step)
|
||||
i += 1
|
||||
j += 1
|
||||
else:
|
||||
ret.append(slice(
|
||||
slice_or_idx1[i].start + slice_or_idx2[j].start * slice_or_idx1[i].step,
|
||||
slice_or_idx1[i].start + slice_or_idx2[j].stop * slice_or_idx1[i].step,
|
||||
slice_or_idx1[i].step * slice_or_idx2[j].step
|
||||
))
|
||||
i += 1
|
||||
j += 1
|
||||
|
||||
def _to_range(transforms) -> tuple[slice | int, ...]:
|
||||
ret = ()
|
||||
for transform in transforms:
|
||||
# For now, assume only NDIndexer transforms.
|
||||
ret = _compose_slice_or_index(
|
||||
ret, tuple(_transform_slice_or_index(i) for i in transform.indices))
|
||||
return ret
|
||||
|
||||
def get(device_id, memory_space, buffer_id, transforms):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
transforms = jax.tree.map(int, transforms)
|
||||
except:
|
||||
raise ValueError('Advanced indexers are not supported on TPU')
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
return shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
].copy()
|
||||
|
||||
def store(device_id, memory_space, buffer_id, transforms, val):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
transforms = jax.tree.map(int, transforms)
|
||||
except:
|
||||
raise ValueError('Advanced indexers are not supported on TPU')
|
||||
val = np.array(val)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
if transforms:
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
] = val
|
||||
else:
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
|
||||
|
||||
def swap(device_id, memory_space, buffer_id, transforms, val):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
transforms = jax.tree.map(int, transforms)
|
||||
except:
|
||||
raise ValueError('Advanced indexers are not supported on TPU')
|
||||
val = np.array(val)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
result = shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
].copy()
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
] = val
|
||||
|
||||
return np.array(result)
|
||||
|
||||
def execute_dma(src, dst, send_sem, recv_sem):
|
||||
# NOTE: `src` is a list of arguments for `get` (device_id, memory_space,
|
||||
# buffer_id, transforms) and `dst` is a list of arguments for `store`
|
||||
# (dst_device_id, dst_memory_space, dst_id, dst_transforms).
|
||||
#
|
||||
# TODO(jburnim): Clean this up.
|
||||
|
||||
# Do the read.
|
||||
data = get(*src)
|
||||
data_size = data.itemsize * data.size
|
||||
|
||||
# Signal the send semaphore.
|
||||
if send_sem is not None:
|
||||
send_sem.signal(data_size, device_id=src[0])
|
||||
|
||||
# Do the write.
|
||||
store(*dst, data)
|
||||
|
||||
# Signal the receive semaphore.
|
||||
recv_sem.signal(data_size, device_id=dst[0])
|
||||
|
||||
def print_memory(device_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
if all(d == 0 for d in device_id):
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
print(shared_memory.mem)
|
||||
|
||||
def dma_start(device_id, src_memory_space, src_id, src_transforms,
|
||||
dst_memory_space, dst_id, dst_transforms,
|
||||
dst_sem,
|
||||
src_sem,
|
||||
dst_device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
src_memory_space, src_id = int(src_memory_space), int(src_id)
|
||||
src_transforms = jax.tree.map(int, src_transforms)
|
||||
dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
|
||||
dst_transforms = jax.tree.map(int, dst_transforms)
|
||||
dst_sem = int(dst_sem)
|
||||
if src_sem is not None:
|
||||
src_sem = int(src_sem)
|
||||
if dst_device_id is not None:
|
||||
dst_device_id = tuple(int(x) for x in dst_device_id)
|
||||
else:
|
||||
dst_device_id = device_id
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
dst_sem = shared_memory.sem[dst_sem]
|
||||
if src_sem is not None:
|
||||
src_sem = shared_memory.sem[src_sem]
|
||||
|
||||
# For now, just execute the DMA immediately.
|
||||
# TODO(jburnim): Execute DMAs asynchronously.
|
||||
execute_dma(
|
||||
(device_id, src_memory_space, src_id, src_transforms),
|
||||
(dst_device_id, dst_memory_space, dst_id, dst_transforms),
|
||||
src_sem,
|
||||
dst_sem)
|
||||
|
||||
def dma_wait(device_id, sem, size):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
sem = int(sem)
|
||||
size = int(size)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
sem = shared_memory.sem[sem]
|
||||
sem.wait(size, device_id)
|
||||
|
||||
def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
|
||||
device_id = tuple(map(int, device_id))
|
||||
sem = int(sem)
|
||||
inc = int(inc)
|
||||
target_device_id = tuple(map(int, target_device_id))
|
||||
|
||||
if target_core_index is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
sem = shared_memory.sem[sem]
|
||||
sem.signal(inc, target_device_id)
|
||||
|
||||
def semaphore_wait(device_id, sem, value):
|
||||
device_id = tuple(map(int, device_id))
|
||||
sem = int(sem)
|
||||
value = int(value)
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
sem = shared_memory.sem[sem]
|
||||
sem.wait(value, device_id)
|
||||
|
||||
def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
|
||||
for transform in transforms:
|
||||
if transform is None:
|
||||
continue
|
||||
shape = transform.transform_shape(shape)
|
||||
dtype = transform.transform_dtype(dtype)
|
||||
return shape, dtype
|
||||
|
||||
@lu.cache
|
||||
def _to_jaxpr(flat_fun, in_avals):
|
||||
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts)
|
||||
return new_jaxpr
|
||||
|
||||
def _is_any(memory_space):
|
||||
return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or
|
||||
(memory_space == pallas_core.MemorySpace.ANY))
|
||||
|
||||
def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
env = {}
|
||||
|
||||
def read(var):
|
||||
if isinstance(var, jax_core.Literal):
|
||||
return var.val
|
||||
else:
|
||||
return env[var]
|
||||
|
||||
def write(var, value):
|
||||
env[var] = value
|
||||
|
||||
jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
|
||||
|
||||
# Get the mesh coordinates.
|
||||
device_coords = tuple(
|
||||
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
|
||||
# TODO(jburnim): Convert to a single integer device ID.
|
||||
# TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
|
||||
# it for each sub-jaxpr.
|
||||
|
||||
# TODO(jburnim): Clean up and finish this evaluation loop. For example:
|
||||
# - Handle missing Pallas primitives, like masked_load.
|
||||
# - Replace the big if-statement with a dictionary of rules.
|
||||
# - Handle other higher-order primitives?
|
||||
# - Megacore.
|
||||
for eqn in jaxpr.eqns:
|
||||
prim = eqn.primitive
|
||||
invals = jax.util.safe_map(read, eqn.invars)
|
||||
|
||||
if prim is primitives.load_p:
|
||||
raise NotImplementedError()
|
||||
|
||||
elif prim is primitives.swap_p:
|
||||
raise NotImplementedError()
|
||||
|
||||
elif prim is lax.cond_p:
|
||||
def _make_branch(jaxpr):
|
||||
return lambda *args: _interpret_jaxpr(
|
||||
jaxpr, *args, compiler_params=compiler_params)
|
||||
out = lax.switch(
|
||||
invals[0],
|
||||
[_make_branch(branch_jaxpr.jaxpr)
|
||||
for branch_jaxpr in eqn.params['branches']],
|
||||
*invals[1:])
|
||||
|
||||
elif prim is lax.scan_p:
|
||||
consts, init_carry, xs = split_list(
|
||||
invals, [eqn.params['num_consts'], eqn.params['num_carry']])
|
||||
def _scan_body(c, a):
|
||||
return split_list(
|
||||
_interpret_jaxpr(eqn.params['jaxpr'].jaxpr, *consts, *c, *a,
|
||||
compiler_params=compiler_params),
|
||||
[eqn.params['num_carry']])
|
||||
carry, out = lax.scan(_scan_body, init_carry, xs=xs,
|
||||
length=eqn.params.get('length', None))
|
||||
out = carry + out
|
||||
|
||||
elif prim is lax.while_p:
|
||||
cond_consts, body_consts, init_vals = split_list(
|
||||
invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']])
|
||||
out = lax.while_loop(
|
||||
lambda args: _interpret_jaxpr(eqn.params['cond_jaxpr'].jaxpr,
|
||||
*cond_consts, *args,
|
||||
compiler_params=compiler_params)[0],
|
||||
lambda args: _interpret_jaxpr(eqn.params['body_jaxpr'].jaxpr,
|
||||
*body_consts, *args,
|
||||
compiler_params=compiler_params),
|
||||
init_vals)
|
||||
|
||||
elif prim is pjit.pjit_p:
|
||||
pjit_jaxpr = eqn.params['jaxpr']
|
||||
def f(*args):
|
||||
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
|
||||
compiler_params=compiler_params)
|
||||
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
|
||||
new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals)
|
||||
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
|
||||
|
||||
elif prim is primitives.run_scoped_p:
|
||||
# Allocate a buffer or semaphore for each element of
|
||||
# eqn.params['jaxpr'].invars .
|
||||
allocs = []
|
||||
for v in eqn.params['jaxpr'].invars:
|
||||
if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
|
||||
allocs.append(callback.io_callback(
|
||||
_allocate_semaphores,
|
||||
jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
|
||||
device_coords,
|
||||
v.aval.shape,
|
||||
ordered=True))
|
||||
else:
|
||||
allocs.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
|
||||
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
|
||||
ordered=True))
|
||||
|
||||
out = _interpret_jaxpr(eqn.params['jaxpr'], *invals, *allocs,
|
||||
compiler_params=compiler_params)
|
||||
|
||||
for a in allocs:
|
||||
if isinstance(a, tuple):
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
None,
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
|
||||
a,
|
||||
ordered=True)
|
||||
else:
|
||||
# TODO(jburnim): Delete semaphores.
|
||||
# callback.io_callback(
|
||||
# _deallocate_semaphores,
|
||||
# None,
|
||||
# device_coords,
|
||||
# a,
|
||||
# ordered=True)
|
||||
pass
|
||||
|
||||
elif prim is state_primitives.get_p:
|
||||
out = callback.io_callback(
|
||||
get,
|
||||
eqn.outvars[0].aval,
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
invals[0],
|
||||
jax.tree.unflatten(eqn.params['tree'], invals[1:]),
|
||||
ordered=True)
|
||||
|
||||
elif prim is state_primitives.swap_p:
|
||||
out = callback.io_callback(
|
||||
swap,
|
||||
eqn.outvars[0].aval,
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
invals[0],
|
||||
jax.tree.unflatten(eqn.params['tree'], invals[2:]),
|
||||
invals[1],
|
||||
ordered=True)
|
||||
|
||||
elif prim is mosaic_primitives.dma_start_p:
|
||||
(src, src_transforms,
|
||||
dst, dst_transforms,
|
||||
dst_sem, dst_sem_transforms,
|
||||
src_sem, src_sem_transforms,
|
||||
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
(orig_src_ref, _, orig_dst_ref, *_
|
||||
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
|
||||
callback.io_callback(
|
||||
dma_start,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space],
|
||||
src, src_transforms,
|
||||
TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space],
|
||||
dst, dst_transforms,
|
||||
state_discharge.transform_array(dst_sem, dst_sem_transforms),
|
||||
state_discharge.transform_array(src_sem, src_sem_transforms),
|
||||
device_id,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
elif prim is mosaic_primitives.dma_wait_p:
|
||||
(src, src_transforms,
|
||||
dst, dst_transforms,
|
||||
dst_sem, dst_sem_transforms,
|
||||
src_sem, src_sem_transforms,
|
||||
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
read_shape, read_dtype = _compute_transformed_shape_and_dtype(
|
||||
eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
|
||||
callback.io_callback(
|
||||
dma_wait,
|
||||
(),
|
||||
device_coords,
|
||||
state_discharge.transform_array(dst_sem, dst_sem_transforms),
|
||||
math.prod(read_shape) * read_dtype.itemsize,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
elif prim is mosaic_primitives.get_barrier_semaphore_p:
|
||||
out = callback.io_callback(
|
||||
get_barrier_semaphore,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
compiler_params['mosaic']['collective_id'],
|
||||
ordered=True)
|
||||
|
||||
elif prim is mosaic_primitives.semaphore_signal_p:
|
||||
sem, sem_transforms, inc, device_id, core_index = (
|
||||
jax.tree.unflatten(eqn.params['args_tree'], invals))
|
||||
callback.io_callback(
|
||||
semaphore_signal,
|
||||
(),
|
||||
device_coords,
|
||||
state_discharge.transform_array(sem, sem_transforms),
|
||||
inc,
|
||||
device_id,
|
||||
core_index,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
elif prim is mosaic_primitives.semaphore_wait_p:
|
||||
sem, sem_transforms, value = (
|
||||
jax.tree.unflatten(eqn.params['args_tree'], invals))
|
||||
callback.io_callback(
|
||||
semaphore_wait,
|
||||
(),
|
||||
device_coords,
|
||||
state_discharge.transform_array(sem, sem_transforms),
|
||||
value,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
else:
|
||||
out = prim.bind(*invals, **eqn.params)
|
||||
|
||||
out = out if prim.multiple_results else [out]
|
||||
jax.util.safe_map(write, eqn.outvars, out)
|
||||
|
||||
return jax.util.safe_map(read, jaxpr.outvars)
|
||||
|
||||
def _initialize_output_vals(
|
||||
block_mappings_output: Iterable[BlockMapping],
|
||||
input_args, input_output_aliases) -> Sequence[jax.Array]:
|
||||
oi_map = {v: k for k, v in input_output_aliases}
|
||||
output_vals = []
|
||||
for i, bm in enumerate(block_mappings_output):
|
||||
if i in oi_map:
|
||||
output_vals.append(input_args[oi_map[i]])
|
||||
else:
|
||||
output_vals.append(primitives.uninitialized_value(
|
||||
bm.array_shape_dtype.shape,
|
||||
bm.array_shape_dtype.dtype))
|
||||
return output_vals
|
||||
|
||||
def _compute_start_indices(block_mapping, loop_idx, *args):
|
||||
block_indices = (
|
||||
jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args))
|
||||
if isinstance(block_mapping.indexing_mode, pallas_core.Blocked):
|
||||
ret = tuple(i if b is pallas_core.mapped else b * i
|
||||
for b, i in zip(block_mapping.block_shape, block_indices))
|
||||
elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked):
|
||||
ret = block_indices
|
||||
else:
|
||||
raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}")
|
||||
return ret
|
||||
|
||||
def _get_next_indices(grid, indices):
|
||||
next_indices = []
|
||||
carry = True
|
||||
for dim_size, index in reversed(list(zip(grid, indices))):
|
||||
i = jnp.where(carry, index + 1, index)
|
||||
carry = dim_size == i
|
||||
next_indices.append(jnp.where(carry, 0, i))
|
||||
return tuple(reversed(next_indices))
|
||||
|
||||
def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
|
||||
start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx)
|
||||
output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
|
||||
squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
|
||||
dtype=np.bool_)])
|
||||
return lax.squeeze(output, squeeze_dims)
|
||||
|
||||
def get_interpret_effects():
|
||||
return {callback._OrderedIOEffect}
|
||||
|
||||
def interpret_pallas_call(
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
compiler_params: Any,
|
||||
cost_estimate: CostEstimate,
|
||||
out_avals: tuple[jax_core.AbstractValue, ...],
|
||||
):
|
||||
del debug, cost_estimate, out_avals
|
||||
|
||||
# args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
|
||||
dynamic_grid_args, scalars, input_args = split_list( # type: ignore
|
||||
args,
|
||||
[grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands],
|
||||
)
|
||||
# TODO(jburnim): Support dynamic grid sizes?
|
||||
grid = grid_mapping.static_grid
|
||||
|
||||
device_coords = tuple(
|
||||
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
|
||||
|
||||
# Allocate buffers in HBM for outputs.
|
||||
io_alias_map = dict(input_output_aliases)
|
||||
output_buffer_ids = []
|
||||
output_vals = _initialize_output_vals(
|
||||
grid_mapping.block_mappings_output, args, input_output_aliases)
|
||||
for out_val in output_vals:
|
||||
output_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
out_val,
|
||||
ordered=True))
|
||||
|
||||
# Allocate buffers for all kernel arguments (e.g., scalars, inputs,
|
||||
# outputs, scratch).
|
||||
kernel_buffer_ids = []
|
||||
for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
|
||||
val,
|
||||
ordered=True))
|
||||
for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
|
||||
output_idx = i - grid_mapping.num_inputs
|
||||
is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
|
||||
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_semaphores,
|
||||
jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
|
||||
device_coords,
|
||||
var.aval.shape,
|
||||
ordered=True))
|
||||
elif is_output and _is_any(var.aval.memory_space):
|
||||
# Don't allocate a buffer -- use the already-allocated HBM output buffer.
|
||||
#
|
||||
# TODO(jburnim): For kernel args in HBM, check that block shape is the
|
||||
# same as for the corresponding pallas_call input, and that the index_map
|
||||
# is trivial.
|
||||
kernel_buffer_ids.append(output_buffer_ids[output_idx])
|
||||
elif (i < grid_mapping.num_inputs) and (i in io_alias_map):
|
||||
# Instead of allocating a new buffer, use the already-allocated
|
||||
# HBM output buffer.
|
||||
assert _is_any(var.aval.memory_space)
|
||||
kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
|
||||
else:
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
|
||||
ordered=True))
|
||||
|
||||
num_inputs = grid_mapping.num_inputs
|
||||
_, input_ids, kernel_output_ids, _ = split_list(
|
||||
kernel_buffer_ids,
|
||||
[grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
|
||||
input_vars, output_vars = split_list(
|
||||
jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs])
|
||||
|
||||
# For kernel inputs that are in HBM, we populate the buffer once before
|
||||
# any kernel invocations.
|
||||
for buffer_id, var, val in zip(input_ids, input_vars, input_args):
|
||||
if not _is_any(var.aval.memory_space):
|
||||
continue
|
||||
if val.shape != var.aval.shape:
|
||||
# TODO(jburnim): Also check that the index_map is trivial.
|
||||
raise ValueError()
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
buffer_id,
|
||||
(),
|
||||
val,
|
||||
ordered=True)
|
||||
|
||||
is_indexing_dim = [
|
||||
tuple(b is pallas_core.mapped for b in bm.block_shape)
|
||||
for bm in grid_mapping.block_mappings
|
||||
]
|
||||
block_shapes = [
|
||||
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
|
||||
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
|
||||
]
|
||||
scalar_ids, in_out_ids, scratch_ids = split_list(
|
||||
kernel_buffer_ids,
|
||||
[grid_mapping.num_index_operands, len(grid_mapping.block_mappings)])
|
||||
|
||||
if grid:
|
||||
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
|
||||
else:
|
||||
# Base case is always one iteration when grid is ()
|
||||
num_iterations = 1
|
||||
|
||||
def body(carry):
|
||||
# The loop carry: (i, loop_idx) --
|
||||
# - i:int32 is the interation index
|
||||
# - loop_idx: tuple[int32] are the program ids for each grid axis
|
||||
i, loop_idx = carry
|
||||
|
||||
if grid_mapping.local_grid_env is not None:
|
||||
local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
|
||||
else:
|
||||
local_grid_env = tuple(
|
||||
pallas_core.GridAxis(idx, b)
|
||||
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
|
||||
if dim not in grid_mapping.vmapped_dims
|
||||
)
|
||||
|
||||
with pallas_core.grid_env(local_grid_env):
|
||||
# Copy slices of the input to the kernel buffers.
|
||||
#
|
||||
# TODO(jburnim): Only copy slices when the index mapping has changed?
|
||||
start_indices = [_compute_start_indices(bm, loop_idx, *scalars)
|
||||
for bm in grid_mapping.block_mappings]
|
||||
for j, var in enumerate(input_vars):
|
||||
if _is_any(var.aval.memory_space):
|
||||
continue
|
||||
sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j],
|
||||
input_args[j], is_indexing_dim[j])
|
||||
assert(sliced_val.shape == var.aval.shape)
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
input_ids[j],
|
||||
(),
|
||||
sliced_val,
|
||||
ordered=True)
|
||||
|
||||
# Invoke the kernel.
|
||||
_interpret_jaxpr(jaxpr, *kernel_buffer_ids,
|
||||
compiler_params=compiler_params)
|
||||
|
||||
# Copy from the kernel buffers to slices of the output in HBM.
|
||||
#
|
||||
# TODO(jburnim): Only copy if the index mapping will change in the
|
||||
# next iteration (or if this is the last iteration)?
|
||||
for j, var in enumerate(output_vars):
|
||||
if _is_any(var.aval.memory_space):
|
||||
continue
|
||||
kernel_output_val = callback.io_callback(
|
||||
get,
|
||||
var.aval,
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
kernel_output_ids[j],
|
||||
(),
|
||||
ordered=True)
|
||||
transform = indexing.NDIndexer(
|
||||
indices=tuple(indexing.ds(st, sz)
|
||||
for st, sz in zip(start_indices[num_inputs + j],
|
||||
block_shapes[num_inputs + j])),
|
||||
shape=output_vals[j].shape,
|
||||
int_indexer_shape=())
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
output_buffer_ids[j],
|
||||
(transform,),
|
||||
kernel_output_val,
|
||||
ordered=True)
|
||||
|
||||
return i + 1, _get_next_indices(grid, loop_idx)
|
||||
|
||||
# TODO(jburnim): Handle parallel grid dimensions + megacore.
|
||||
_ = lax.while_loop(
|
||||
lambda carry: carry[0] < num_iterations,
|
||||
body,
|
||||
(jnp.int32(0), (jnp.int32(0),) * len(grid))
|
||||
)
|
||||
|
||||
# Read the output from the allocated output buffers.
|
||||
ret = [
|
||||
callback.io_callback(
|
||||
get,
|
||||
val,
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
output_buffer_id,
|
||||
(),
|
||||
ordered=True)
|
||||
for val, output_buffer_id in zip(output_vals, output_buffer_ids)
|
||||
]
|
||||
|
||||
for buffer_id in output_buffer_ids:
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
buffer_id,
|
||||
ordered=True)
|
||||
for buffer_id, var in zip(kernel_buffer_ids, jaxpr.invars):
|
||||
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
|
||||
pass
|
||||
else:
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
(),
|
||||
device_coords,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
buffer_id,
|
||||
ordered=True)
|
||||
|
||||
return ret
|
@ -17,7 +17,9 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
from functools import partial, reduce
|
||||
import types
|
||||
from typing import Any, Literal
|
||||
|
||||
import jax
|
||||
@ -82,19 +84,31 @@ pallas_call_p.def_impl(_pallas_call_impl)
|
||||
|
||||
|
||||
def _pallas_call_abstract_eval(
|
||||
*avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
|
||||
*avals,
|
||||
out_avals: tuple[jax_core.AbstractValue, ...],
|
||||
interpret,
|
||||
backend,
|
||||
**params
|
||||
):
|
||||
del avals
|
||||
|
||||
if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
|
||||
# Report effects that will be introduced when running/lowering
|
||||
# mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call .
|
||||
effs = mosaic_tpu_interpret.get_interpret_effects()
|
||||
else:
|
||||
effs = jax_core.no_effects
|
||||
|
||||
# Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
|
||||
return [
|
||||
jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
|
||||
if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
|
||||
else a
|
||||
for a in out_avals
|
||||
]
|
||||
], effs
|
||||
|
||||
|
||||
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
pallas_call_p.def_effectful_abstract_eval(_pallas_call_abstract_eval)
|
||||
|
||||
|
||||
def _pallas_call_jvp_rule(
|
||||
@ -1230,9 +1244,12 @@ def _pallas_call_lowering(
|
||||
if params['jaxpr'].constvars:
|
||||
raise ValueError('Cannot lower a pallas_call with constants.')
|
||||
if interpret:
|
||||
impl = partial(hlo_interpreter.pallas_call_hlo_interpret,
|
||||
backend=backend,
|
||||
**params)
|
||||
if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
|
||||
impl = partial(mosaic_tpu_interpret.interpret_pallas_call, **params)
|
||||
else:
|
||||
impl = partial(hlo_interpreter.pallas_call_hlo_interpret,
|
||||
backend=backend,
|
||||
**params)
|
||||
return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
|
||||
|
||||
def cpu_lowering(ctx: mlir.LoweringRuleContext,
|
||||
@ -1681,3 +1698,10 @@ try:
|
||||
from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend
|
||||
except ImportError:
|
||||
mosaic_tpu_backend = None # type: ignore
|
||||
|
||||
try:
|
||||
from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret
|
||||
except ImportError:
|
||||
mosaic_tpu_interpret = types.SimpleNamespace( # type: ignore
|
||||
TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)),
|
||||
)
|
||||
|
@ -426,6 +426,32 @@ jax_multiplatform_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_interpret_test",
|
||||
srcs = [
|
||||
"tpu_pallas_interpret_test.py",
|
||||
],
|
||||
disable_configs = ["cpu_shardy"],
|
||||
enable_backends = ["cpu"],
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_tpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_interpret_distributed_test",
|
||||
srcs = [
|
||||
"tpu_pallas_interpret_distributed_test.py",
|
||||
],
|
||||
disable_configs = ["cpu_shardy"],
|
||||
enable_backends = ["cpu"],
|
||||
deps = [
|
||||
"//third_party/py/jax:pallas",
|
||||
"//third_party/py/jax:pallas_tpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_paged_attention_kernel_test",
|
||||
srcs = ["tpu_paged_attention_kernel_test.py"],
|
||||
|
995
tests/pallas/tpu_pallas_interpret_distributed_test.py
Normal file
995
tests/pallas/tpu_pallas_interpret_distributed_test.py
Normal file
@ -0,0 +1,995 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for TPU-specific interpret mode.
|
||||
|
||||
To work around https://github.com/jax-ml/jax/issues/25671 , this file
|
||||
contains only tests that use shard_map.
|
||||
"""
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.pallas.mosaic.interpret as mosaic_interpret
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental import shard_map
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
P = jax.sharding.PartitionSpec
|
||||
|
||||
|
||||
class InterpretDistributedTest(jtu.JaxTestCase):
|
||||
|
||||
def test_right_permute_example(self):
|
||||
num_devices = jax.device_count()
|
||||
if num_devices < 4:
|
||||
self.skipTest(f'requires at least 4 devices, found {num_devices}')
|
||||
partition = P(None, 'x')
|
||||
mesh = jax.make_mesh((num_devices,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, partition)
|
||||
|
||||
# Create an input array that shards the last dimension across
|
||||
# all devices.
|
||||
input_arr = jax.random.uniform(
|
||||
jax.random.key(0), (8, 128 * num_devices), dtype=jnp.float32)
|
||||
input_arr = jax.device_put(input_arr, sharding)
|
||||
|
||||
def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
|
||||
my_id = lax.axis_index('x')
|
||||
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
|
||||
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
|
||||
|
||||
barrier_sem = pltpu.get_barrier_semaphore()
|
||||
def _body(ijk):
|
||||
i, (j, k) = ijk
|
||||
lax.cond(
|
||||
(i == 0) | (j == 0),
|
||||
lambda: pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH),
|
||||
lambda: pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH))
|
||||
return (i + 1, (j + 1, k + 1))
|
||||
lax.while_loop(lambda ijk: ijk[0] < 2, _body, (0, (0, 0)))
|
||||
pltpu.semaphore_wait(barrier_sem, 2)
|
||||
|
||||
def _body2(i, a):
|
||||
remote_copy_op = pltpu.make_async_remote_copy(
|
||||
src_ref=input_ref,
|
||||
dst_ref=output_ref,
|
||||
send_sem=send_sem,
|
||||
recv_sem=recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
remote_copy_op.start()
|
||||
remote_copy_op.wait()
|
||||
|
||||
return i + 1, a + 1
|
||||
_ = lax.scan(_body2, 0, jnp.arange(4.0), unroll=2)
|
||||
|
||||
out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
|
||||
in_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
scratch_shapes=(
|
||||
# We allocate DMA semaphores in scratch memory.
|
||||
[pltpu.SemaphoreType.DMA] * 2
|
||||
),
|
||||
)
|
||||
right_permute = pl.pallas_call(
|
||||
right_permute_kernel,
|
||||
out_shape=out_shape,
|
||||
grid_spec=grid_spec,
|
||||
compiler_params=pltpu.TPUCompilerParams(collective_id=13),
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
)
|
||||
# Wrap the kernel within a shard_map to call.
|
||||
pallas_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
right_permute,
|
||||
mesh=mesh,
|
||||
in_specs=partition,
|
||||
out_specs=partition,
|
||||
check_rep=False,
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
# Compare Pallas result to XLA shard_map result.
|
||||
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))
|
||||
xla_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
lambda x: lax.ppermute(x, 'x', perm),
|
||||
mesh=mesh, in_specs=partition, out_specs=partition)
|
||||
)(input_arr)
|
||||
|
||||
np.testing.assert_allclose(xla_result, pallas_result)
|
||||
|
||||
|
||||
def test_all_gather_example(self):
|
||||
num_devices = jax.device_count()
|
||||
if num_devices < 4:
|
||||
self.skipTest(f'requires at least 4 devices, found {num_devices}')
|
||||
partition = P('x', None)
|
||||
mesh = jax.make_mesh((num_devices,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, partition)
|
||||
|
||||
# Create an input array that shards the first dimension across
|
||||
# all devices.
|
||||
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
|
||||
input_arr = jax.device_put(input_arr, sharding)
|
||||
|
||||
def all_gather_kernel(input_ref,
|
||||
output_ref,
|
||||
local_copy_sem,
|
||||
send_sem,
|
||||
recv_sems):
|
||||
outer_step = pl.program_id(0)
|
||||
my_id = lax.axis_index('x')
|
||||
left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
|
||||
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
|
||||
copy_slot = my_id - outer_step
|
||||
copy_slot = lax.rem(copy_slot + num_devices, jnp.int32(num_devices))
|
||||
|
||||
@pl.when(outer_step == 0)
|
||||
def _():
|
||||
# Barrier with both neighbors at the start, since we will be
|
||||
# communicating with both.
|
||||
barrier_sem = pltpu.get_barrier_semaphore()
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_wait(barrier_sem, 2)
|
||||
|
||||
local_copy_op = pltpu.make_async_copy(
|
||||
src_ref=input_ref,
|
||||
dst_ref=output_ref.at[my_id],
|
||||
sem=local_copy_sem,
|
||||
)
|
||||
local_copy_op.start()
|
||||
local_copy_op.wait()
|
||||
|
||||
# Copy to our right neighbor.
|
||||
# Note that we will also be receiving data from our left neighbor,
|
||||
# but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
|
||||
# that the indices do not need to be symmetric between remote DMAs.
|
||||
remote_copy_op = pltpu.make_async_remote_copy(
|
||||
src_ref=output_ref.at[copy_slot],
|
||||
dst_ref=output_ref.at[copy_slot],
|
||||
send_sem=send_sem,
|
||||
recv_sem=recv_sems.at[outer_step],
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
remote_copy_op.start()
|
||||
remote_copy_op.wait()
|
||||
|
||||
out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
scratch_shapes=(
|
||||
# DMA semaphores are allocated in scratch memory.
|
||||
# We allocated one semaphore for a local HBM-VMEM copy,
|
||||
# and one for the remote send semaphore.
|
||||
[pltpu.SemaphoreType.DMA] * 2
|
||||
# We additionally allocate one receive semaphore per device.
|
||||
# This is to avoid situations where we have multiple
|
||||
# DMAs in flight, as we do not want to share a receive
|
||||
# semaphore between the DMAs.
|
||||
+ [pltpu.SemaphoreType.DMA((num_devices-1,))]
|
||||
),
|
||||
grid=(num_devices-1,)
|
||||
)
|
||||
|
||||
all_gather = pl.pallas_call(
|
||||
all_gather_kernel,
|
||||
out_shape=out_shape,
|
||||
grid_spec=grid_spec,
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
|
||||
)
|
||||
|
||||
# Wrap the kernel within a shard_map to call.
|
||||
pallas_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
all_gather,
|
||||
mesh=mesh,
|
||||
in_specs=partition,
|
||||
out_specs=partition,
|
||||
check_rep=False
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
# Compare Pallas result to XLA shard_map result.
|
||||
xla_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
lambda x: lax.all_gather(x, 'x'),
|
||||
mesh=mesh, in_specs=partition, out_specs=partition
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
np.testing.assert_allclose(xla_result, pallas_result)
|
||||
|
||||
def test_all_reduce_sum_example(self):
|
||||
num_devices = jax.device_count()
|
||||
if num_devices < 4:
|
||||
self.skipTest(f'requires at least 4 devices, found {num_devices}')
|
||||
partition = P(None, 'x')
|
||||
mesh = jax.make_mesh((num_devices,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, partition)
|
||||
|
||||
input_arr = jax.random.uniform(
|
||||
jax.random.key(0), shape=(8, 128 * num_devices))
|
||||
input_arr = jax.device_put(input_arr, sharding)
|
||||
|
||||
def all_reduce_kernel(
|
||||
x_ref,
|
||||
o_ref,
|
||||
hbm_scratch,
|
||||
copy_sem,
|
||||
remote_recv_sem,
|
||||
remote_send_sem,
|
||||
capacity_sem,
|
||||
receive_scratch,
|
||||
):
|
||||
outer_step = pl.program_id(0)
|
||||
working_slot = lax.rem(outer_step, jnp.int32(2))
|
||||
receiving_slot = 1 - working_slot
|
||||
|
||||
my_id = lax.axis_index('x')
|
||||
right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
|
||||
left_neighbor = lax.rem(my_id - 1 + num_devices, jnp.int32(num_devices))
|
||||
|
||||
@pl.when(outer_step == 0)
|
||||
def _():
|
||||
# Barrier with both neighbors at the start, since we will be
|
||||
# communicating with both.
|
||||
barrier_sem = pltpu.get_barrier_semaphore()
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_wait(barrier_sem, 2)
|
||||
|
||||
# Initialize o_ref, acc_scratch, and hbm_scratch.
|
||||
o_ref[...] = jnp.zeros_like(o_ref)
|
||||
receive_scratch[...] = jnp.zeros_like(receive_scratch)
|
||||
initial_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=x_ref,
|
||||
dst_ref=hbm_scratch.at[working_slot],
|
||||
send_sem=remote_send_sem,
|
||||
recv_sem=remote_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
initial_copy.start()
|
||||
initial_copy.wait()
|
||||
|
||||
# Signal to our left neighbor that we are ready to receive.
|
||||
# Without this signal, our left neighbor can be >=1 iteration ahead,
|
||||
# meaning it could write into our working slot.
|
||||
pltpu.semaphore_signal(
|
||||
capacity_sem,
|
||||
inc=1,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
# Copy the partial result our left neighbor sent to us into VMEM for
|
||||
# computation.
|
||||
local_copy = pltpu.make_async_copy(
|
||||
src_ref=hbm_scratch.at[working_slot],
|
||||
dst_ref=receive_scratch,
|
||||
sem=copy_sem,
|
||||
)
|
||||
local_copy.start()
|
||||
|
||||
# Block until our right neighbor is ready to receive.
|
||||
pltpu.semaphore_wait(capacity_sem, 1)
|
||||
# Pass the value to our right neighbor.
|
||||
remote_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=hbm_scratch.at[working_slot],
|
||||
dst_ref=hbm_scratch.at[receiving_slot],
|
||||
send_sem=remote_send_sem,
|
||||
recv_sem=remote_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
remote_copy.start()
|
||||
# Finish local copy and accumulate while remote_copy is happening.
|
||||
local_copy.wait()
|
||||
o_ref[...] += receive_scratch[...]
|
||||
# Block until remote copy finishes.
|
||||
remote_copy.wait()
|
||||
|
||||
out_shape = (
|
||||
jax.ShapeDtypeStruct((8, 128), jnp.float32),
|
||||
# We allocate the double-buffer as a Pallas output so that it is
|
||||
# resident in HBM.
|
||||
jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch
|
||||
)
|
||||
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
# Our input lives in VMEM
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
],
|
||||
out_specs=[
|
||||
# Our output lives in VMEM
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
# Our double-buffer lives in HBM
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
grid=(num_devices,),
|
||||
scratch_shapes=(
|
||||
[pltpu.SemaphoreType.DMA] * 3
|
||||
+ [pltpu.SemaphoreType.REGULAR] # capacity_sem
|
||||
+ [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch
|
||||
),
|
||||
)
|
||||
|
||||
kernel = pl.pallas_call(
|
||||
all_reduce_kernel,
|
||||
out_shape=out_shape,
|
||||
grid_spec=grid_spec,
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
|
||||
)
|
||||
|
||||
pallas_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
kernel,
|
||||
mesh=mesh,
|
||||
in_specs=partition,
|
||||
out_specs=partition,
|
||||
check_rep=False,
|
||||
)
|
||||
)(input_arr)
|
||||
pallas_result = jax.block_until_ready(pallas_result)[0]
|
||||
|
||||
def lax_sum(x):
|
||||
return lax.psum(x, 'x')
|
||||
|
||||
xla_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_scatter_sum_example(self):
|
||||
num_devices = jax.device_count()
|
||||
if num_devices < 4:
|
||||
self.skipTest(f'requires at least 4 devices, found {num_devices}')
|
||||
partition = P(None, 'x')
|
||||
mesh = jax.make_mesh((num_devices,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, partition)
|
||||
|
||||
# We need a block size of (16, 128) to ensure that a half-slice is at least
|
||||
# of size (8, 128), which is the size of a VREG. This makes tiling easier
|
||||
# for the compiler.
|
||||
block_size = (16, 128)
|
||||
input_arr = jax.random.uniform(
|
||||
jax.random.key(0),
|
||||
shape=(block_size[0] * num_devices, block_size[1] * num_devices),
|
||||
dtype=jnp.float32,
|
||||
)
|
||||
input_arr = jax.device_put(input_arr, sharding)
|
||||
|
||||
LEFT = 0
|
||||
RIGHT = 1
|
||||
|
||||
def mod(x, n):
|
||||
return lax.rem(x + n, n)
|
||||
|
||||
def signal(left_or_right, semaphore):
|
||||
my_id = lax.axis_index('x')
|
||||
if left_or_right == LEFT:
|
||||
neighbor = mod(my_id - 1, jnp.int32(num_devices))
|
||||
else:
|
||||
neighbor = mod(my_id + 1, jnp.int32(num_devices))
|
||||
pltpu.semaphore_signal(
|
||||
semaphore,
|
||||
inc=1,
|
||||
device_id=(neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
def reduce_scatter_kernel(
|
||||
x_ref,
|
||||
o_ref,
|
||||
hbm_scratch,
|
||||
local_copy_sem,
|
||||
left_recv_sem,
|
||||
left_send_sem,
|
||||
right_recv_sem,
|
||||
right_send_sem,
|
||||
left_capacity_sem,
|
||||
right_capacity_sem,
|
||||
accum_scratch,
|
||||
):
|
||||
outer_step = pl.program_id(0)
|
||||
phase = pl.program_id(1)
|
||||
is_start = jnp.logical_and(outer_step == 0, phase == 0)
|
||||
last_iteration = outer_step == pl.num_programs(0) - 1
|
||||
|
||||
working_slot = lax.rem(outer_step, jnp.int32(2))
|
||||
receiving_slot = 1 - working_slot
|
||||
my_id = lax.axis_index('x')
|
||||
right_neighbor = mod(my_id + 1, jnp.int32(num_devices))
|
||||
left_neighbor = mod(my_id - 1, jnp.int32(num_devices))
|
||||
|
||||
left_copy_device = mod(my_id + outer_step + 1, jnp.int32(num_devices))
|
||||
right_copy_device = mod(my_id - outer_step - 1, jnp.int32(num_devices))
|
||||
# Slices can be specified using pl.ds(start, size)
|
||||
left_copy_slice = pl.ds(0, block_size[0] // 2)
|
||||
right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
|
||||
current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)
|
||||
|
||||
@pl.when(is_start)
|
||||
def _():
|
||||
# Barrier with both neighbors at the start, since we will be
|
||||
# communicating with both.
|
||||
barrier_sem = pltpu.get_barrier_semaphore()
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_wait(barrier_sem, 2)
|
||||
|
||||
initial_left_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=x_ref.at[my_id, left_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
|
||||
send_sem=left_send_sem,
|
||||
recv_sem=left_recv_sem,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
initial_right_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=x_ref.at[my_id, right_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
|
||||
send_sem=right_send_sem,
|
||||
recv_sem=right_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
left_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
|
||||
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
|
||||
send_sem=left_send_sem,
|
||||
recv_sem=left_recv_sem,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
right_copy = pltpu.make_async_remote_copy(
|
||||
# Note: Right copy is flipped with regards to slots since we are copying
|
||||
# to the next outer_step iteration.
|
||||
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
|
||||
send_sem=right_send_sem,
|
||||
recv_sem=right_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
# --- Prologue ---
|
||||
@pl.when(is_start)
|
||||
def _():
|
||||
# Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
|
||||
o_ref[...] = jnp.zeros_like(o_ref[...])
|
||||
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
|
||||
|
||||
initial_left_copy.start()
|
||||
initial_left_copy.wait()
|
||||
initial_right_copy.start()
|
||||
|
||||
# We tell our left neighbor that it is allowed to send to the right.
|
||||
# (and vice versa for right neighbor)
|
||||
signal(LEFT, right_capacity_sem)
|
||||
signal(RIGHT, left_capacity_sem)
|
||||
|
||||
# --- Body ---
|
||||
# At the beginning of our kernel body, we start a DMA which copies
|
||||
# the result we computed in the previous phase to our neighbor.
|
||||
# This allows us to overlap the communication of sending our previous phase
|
||||
# with the computation for the current phase.
|
||||
@pl.when(~is_start)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
# We block here until our right neighbor tells use we can send to
|
||||
# the right.
|
||||
pltpu.semaphore_wait(right_capacity_sem, 1)
|
||||
right_copy.start()
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
# We block here until our left neighbor tells use we can send to
|
||||
# the left.
|
||||
pltpu.semaphore_wait(left_capacity_sem, 1)
|
||||
left_copy.start()
|
||||
|
||||
local_copy = pltpu.make_async_copy(
|
||||
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
|
||||
dst_ref=accum_scratch,
|
||||
sem=local_copy_sem,
|
||||
)
|
||||
local_copy.start()
|
||||
local_copy.wait()
|
||||
|
||||
@pl.when(~last_iteration)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
|
||||
|
||||
local_copy = pltpu.make_async_copy(
|
||||
src_ref=accum_scratch,
|
||||
dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
|
||||
sem=local_copy_sem,
|
||||
)
|
||||
local_copy.start()
|
||||
local_copy.wait()
|
||||
|
||||
@pl.when(is_start)
|
||||
def _():
|
||||
initial_right_copy.wait()
|
||||
|
||||
# At the end of our kernel body, we wait on the DMA of the previous phase
|
||||
# to make sure the results are ready for the next phase.
|
||||
@pl.when(~is_start)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
right_copy.wait()
|
||||
signal(LEFT, right_capacity_sem)
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
left_copy.wait()
|
||||
signal(RIGHT, left_capacity_sem)
|
||||
|
||||
# --- Epilogue ---
|
||||
# Store result on last iteration.
|
||||
@pl.when(last_iteration)
|
||||
def _():
|
||||
# Clean up semaphores so that they exit with a value of 0.
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
o_ref[left_copy_slice, ...] = accum_scratch[...]
|
||||
pltpu.semaphore_wait(right_capacity_sem, 1)
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
o_ref[right_copy_slice, ...] = accum_scratch[...]
|
||||
pltpu.semaphore_wait(left_capacity_sem, 1)
|
||||
|
||||
out_shape = (
|
||||
jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output
|
||||
# Shape: [working/recv, block[0], block[1]]
|
||||
jax.ShapeDtypeStruct(
|
||||
(2, block_size[0], block_size[1]), jnp.float32
|
||||
), # hbm_scratch
|
||||
)
|
||||
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
],
|
||||
out_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
grid=(num_devices, 2),
|
||||
scratch_shapes=(
|
||||
[pltpu.SemaphoreType.DMA] * 5
|
||||
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
|
||||
+ [
|
||||
pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
|
||||
] # accum_scratch
|
||||
),
|
||||
)
|
||||
|
||||
def pallas_reduce_scatter(input_arr):
|
||||
input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
|
||||
return pl.pallas_call(
|
||||
reduce_scatter_kernel,
|
||||
out_shape=out_shape,
|
||||
grid_spec=grid_spec,
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
compiler_params=pltpu.TPUCompilerParams(collective_id=7),
|
||||
)(input_arr)[0]
|
||||
|
||||
pallas_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
pallas_reduce_scatter,
|
||||
mesh=mesh,
|
||||
in_specs=P(None, 'x'),
|
||||
out_specs=P('x', None),
|
||||
check_rep=False,
|
||||
)
|
||||
)(input_arr)
|
||||
pallas_result = jax.block_until_ready(pallas_result)
|
||||
|
||||
# Compare our result to XLA.
|
||||
def lax_reduce_sum_scatter(x):
|
||||
x = x.reshape(num_devices, block_size[0], block_size[1])
|
||||
return lax.psum_scatter(x, 'x')
|
||||
|
||||
xla_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
lax_reduce_sum_scatter,
|
||||
mesh=mesh,
|
||||
in_specs=P(None, 'x'),
|
||||
out_specs=P('x', None),
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
|
||||
|
||||
def test_reduce_scatter_sum_with_emit_pipeline_example(self):
|
||||
self.skipTest('requires a patched pallas.emit_pipeline to specify/fake '
|
||||
'the TPU generation')
|
||||
if jax.config.jax_enable_x64:
|
||||
self.skipTest('pallas.emit_pipeline + x64 is not currently supported')
|
||||
num_devices = jax.device_count()
|
||||
if num_devices < 4:
|
||||
self.skipTest(f'requires at least 4 devices, found {num_devices}')
|
||||
partition = P(None, 'x')
|
||||
mesh = jax.make_mesh((num_devices,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, partition)
|
||||
|
||||
# We pick a large outer kernel block size that we do not want to place
|
||||
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
|
||||
# principle this can be much larger.
|
||||
outer_block_size = (512, 512)
|
||||
# We pick a smaller VMEM block size for the inner kernel.
|
||||
inner_block_size = (128, 128)
|
||||
input_arr = jax.random.uniform(
|
||||
jax.random.key(0),
|
||||
shape=(
|
||||
outer_block_size[0] * num_devices,
|
||||
outer_block_size[1] * num_devices,
|
||||
),
|
||||
)
|
||||
input_arr = jax.device_put(input_arr, sharding)
|
||||
|
||||
inner_grid = (
|
||||
outer_block_size[0] // inner_block_size[0] // 2,
|
||||
outer_block_size[1] // inner_block_size[1],
|
||||
)
|
||||
inner_block_spec = pl.BlockSpec(
|
||||
index_map=lambda i, j: (i, j),
|
||||
block_shape=inner_block_size,
|
||||
memory_space=pltpu.TPUMemorySpace.ANY,
|
||||
)
|
||||
|
||||
LEFT = 0
|
||||
RIGHT = 1
|
||||
|
||||
def mod(x, n):
|
||||
return lax.rem(x + n, n)
|
||||
|
||||
def signal(left_or_right, semaphore):
|
||||
my_id = lax.axis_index('x')
|
||||
if left_or_right == LEFT:
|
||||
neighbor = mod(my_id - 1, num_devices)
|
||||
else:
|
||||
neighbor = mod(my_id + 1, num_devices)
|
||||
pltpu.semaphore_signal(
|
||||
semaphore,
|
||||
inc=1,
|
||||
device_id=(neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
def reduce_scatter_kernel(
|
||||
x_ref,
|
||||
o_ref,
|
||||
hbm_scratch,
|
||||
left_recv_sem,
|
||||
left_send_sem,
|
||||
copy_sem,
|
||||
right_recv_sem,
|
||||
right_send_sem,
|
||||
left_capacity_sem,
|
||||
right_capacity_sem,
|
||||
):
|
||||
outer_step = pl.program_id(0)
|
||||
phase = pl.program_id(1)
|
||||
is_start = jnp.logical_and(outer_step == 0, phase == 0)
|
||||
last_iteration = outer_step == pl.num_programs(0) - 1
|
||||
|
||||
working_slot = lax.rem(outer_step, 2)
|
||||
receiving_slot = 1 - working_slot
|
||||
my_id = lax.axis_index('x')
|
||||
right_neighbor = mod(my_id + 1, num_devices)
|
||||
left_neighbor = mod(my_id - 1, num_devices)
|
||||
|
||||
left_copy_device = mod(my_id + outer_step + 1, num_devices)
|
||||
right_copy_device = mod(my_id - outer_step - 1, num_devices)
|
||||
left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
|
||||
right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
|
||||
current_phase_slice = pl.ds(
|
||||
phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
|
||||
)
|
||||
|
||||
initial_left_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=x_ref.at[my_id, left_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
|
||||
send_sem=left_send_sem,
|
||||
recv_sem=left_recv_sem,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
initial_right_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=x_ref.at[my_id, right_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
|
||||
send_sem=right_send_sem,
|
||||
recv_sem=right_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
left_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
|
||||
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
|
||||
send_sem=left_send_sem,
|
||||
recv_sem=left_recv_sem,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
right_copy = pltpu.make_async_remote_copy(
|
||||
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
|
||||
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
|
||||
send_sem=right_send_sem,
|
||||
recv_sem=right_recv_sem,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
|
||||
# --- Prologue ---
|
||||
@pl.when(is_start)
|
||||
def _():
|
||||
# Barrier with both neighbors at the start, since we will be
|
||||
# communicating with both.
|
||||
barrier_sem = pltpu.get_barrier_semaphore()
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(left_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_signal(
|
||||
barrier_sem,
|
||||
inc=1,
|
||||
device_id=(right_neighbor,),
|
||||
device_id_type=pltpu.DeviceIdType.MESH,
|
||||
)
|
||||
pltpu.semaphore_wait(barrier_sem, 2)
|
||||
|
||||
initial_left_copy.start()
|
||||
initial_left_copy.wait()
|
||||
initial_right_copy.start()
|
||||
|
||||
# We tell our left neighbor that it is allowed to send to the right.
|
||||
# (and vice versa for right neighbor)
|
||||
signal(LEFT, right_capacity_sem)
|
||||
signal(RIGHT, left_capacity_sem)
|
||||
|
||||
@pl.when(~is_start)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
# We block here until our right neighbor tells use we can send to
|
||||
# the right.
|
||||
pltpu.semaphore_wait(right_capacity_sem, 1)
|
||||
right_copy.start()
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
# We block here until our left neighbor tells use we can send to
|
||||
# the left.
|
||||
pltpu.semaphore_wait(left_capacity_sem, 1)
|
||||
left_copy.start()
|
||||
|
||||
# --- Body ---
|
||||
def inner_kernel(input_ref, accum_ref):
|
||||
# We do not explicitly use += because we set should_accumulate_out=True.
|
||||
accum_ref[...] = input_ref[...]
|
||||
|
||||
accum_pipeline = pltpu.emit_pipeline(
|
||||
inner_kernel,
|
||||
in_specs=[inner_block_spec],
|
||||
out_specs=inner_block_spec,
|
||||
should_accumulate_out=True,
|
||||
grid=inner_grid,
|
||||
)
|
||||
|
||||
@pl.when(~last_iteration)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
accum_pipeline(
|
||||
x_ref.at[left_copy_device, left_copy_slice],
|
||||
hbm_scratch.at[working_slot, left_copy_slice],
|
||||
)
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
accum_pipeline(
|
||||
x_ref.at[right_copy_device, right_copy_slice],
|
||||
hbm_scratch.at[working_slot, right_copy_slice],
|
||||
)
|
||||
|
||||
# --- Epilogue ---
|
||||
@pl.when(is_start)
|
||||
def _():
|
||||
initial_right_copy.wait()
|
||||
|
||||
@pl.when(~is_start)
|
||||
def _():
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
right_copy.wait()
|
||||
signal(LEFT, right_capacity_sem)
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
left_copy.wait()
|
||||
signal(RIGHT, left_capacity_sem)
|
||||
|
||||
# Store result on last iteration.
|
||||
@pl.when(last_iteration)
|
||||
def _():
|
||||
output_copy = pltpu.make_async_copy(
|
||||
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
|
||||
dst_ref=o_ref.at[current_phase_slice],
|
||||
sem=copy_sem,
|
||||
)
|
||||
output_copy.start()
|
||||
output_copy.wait()
|
||||
|
||||
# Clean up semaphores so that they exit with a value of 0.
|
||||
@pl.when(phase == LEFT)
|
||||
def _():
|
||||
pltpu.semaphore_wait(right_capacity_sem, 1)
|
||||
|
||||
@pl.when(phase == RIGHT)
|
||||
def _():
|
||||
pltpu.semaphore_wait(left_capacity_sem, 1)
|
||||
|
||||
|
||||
out_shape = (
|
||||
jax.ShapeDtypeStruct(
|
||||
(outer_block_size[0], outer_block_size[1]), jnp.float32
|
||||
),
|
||||
# Shape: [working/recv, block[0], block[1]]
|
||||
jax.ShapeDtypeStruct(
|
||||
(2, outer_block_size[0], outer_block_size[1]), jnp.float32
|
||||
), # hbm_scratch
|
||||
)
|
||||
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
out_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
],
|
||||
grid=(num_devices, 2),
|
||||
scratch_shapes=(
|
||||
[pltpu.SemaphoreType.DMA] * 5
|
||||
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
|
||||
),
|
||||
)
|
||||
|
||||
def pallas_reduce_scatter(input_arr):
|
||||
input_arr = input_arr.reshape(
|
||||
num_devices, outer_block_size[0], outer_block_size[1]
|
||||
)
|
||||
return pl.pallas_call(
|
||||
reduce_scatter_kernel,
|
||||
out_shape=out_shape,
|
||||
grid_spec=grid_spec,
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
compiler_params=pltpu.TPUCompilerParams(collective_id=19),
|
||||
)(input_arr)[0]
|
||||
|
||||
pallas_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
pallas_reduce_scatter,
|
||||
mesh=mesh,
|
||||
in_specs=P(None, 'x'),
|
||||
out_specs=P('x', None),
|
||||
check_rep=False,
|
||||
)
|
||||
)(input_arr)
|
||||
pallas_result = jax.block_until_ready(pallas_result)
|
||||
|
||||
def lax_reduce_sum_scatter(x):
|
||||
x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
|
||||
return lax.psum_scatter(x, 'x')
|
||||
|
||||
xla_result = jax.jit(
|
||||
shard_map.shard_map(
|
||||
lax_reduce_sum_scatter,
|
||||
mesh=mesh,
|
||||
in_specs=P(None, 'x'),
|
||||
out_specs=P('x', None),
|
||||
)
|
||||
)(input_arr)
|
||||
|
||||
np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
67
tests/pallas/tpu_pallas_interpret_test.py
Normal file
67
tests/pallas/tpu_pallas_interpret_test.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for TPU-specific interpret mode.
|
||||
|
||||
To work around https://github.com/jax-ml/jax/issues/25671 , this file
|
||||
contains only tests that do not use shard_map.
|
||||
"""
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.pallas.mosaic.interpret as mosaic_interpret
|
||||
from jax.experimental import pallas as pl
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class InterpretTest(jtu.JaxTestCase):
|
||||
|
||||
def test_matmul_example(self):
|
||||
num_devices = jax.device_count()
|
||||
if num_devices > 1:
|
||||
# Workaround for https://github.com/jax-ml/jax/issues/25671
|
||||
self.skipTest(f'requires 1 device, found {num_devices}')
|
||||
|
||||
def matmul_kernel(x_ref, y_ref, z_ref):
|
||||
z_ref[...] = x_ref[...] @ y_ref[...]
|
||||
|
||||
@jax.jit
|
||||
def matmul(x: jax.Array, y: jax.Array):
|
||||
return pl.pallas_call(
|
||||
matmul_kernel,
|
||||
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
|
||||
grid=(2, 2),
|
||||
in_specs=[
|
||||
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
|
||||
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
|
||||
],
|
||||
out_specs=pl.BlockSpec(
|
||||
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
|
||||
),
|
||||
interpret=mosaic_interpret.TPUInterpretParams(),
|
||||
)(x, y)
|
||||
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.normal(k1, (1024, 1024))
|
||||
y = jax.random.normal(k2, (1024, 1024))
|
||||
z = matmul(x, y)
|
||||
np.testing.assert_allclose(z, x @ y, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user