[Mosaic] Several fixes/improvements for the new TPU interpret mode.

- Checks bounds for reads and writes to shared memory.
- Pads kernel arguments when necessary.
- Fixes support for input-output aliasing.
- Fixes handling of vmap'ed dimensions.
- Supports un-masked `pl.load` and masked or un-masked `pl.swap` (see the sketch below).
- Switches to single integer device IDs instead of tuples.
- Adds better error messages for unsupported primitives: `for_p`, `atomic_rmw_p`, and `atomic_cas_p`.
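As a rough illustration of the masked `pl.swap` semantics added here, below is a minimal NumPy sketch (the names `buf`, `idx`, `val`, and `mask` are illustrative only, not part of the Pallas API): masked positions exchange their values with the buffer, while unmasked positions leave the buffer unchanged and pass `val` through in the result.

import numpy as np

def masked_swap_sketch(buf, idx, val, mask):
  # Read the current contents of the selected region.
  old = buf[idx].copy()
  # Write val only where mask is True.
  buf[idx] = np.where(mask, val, old)
  # Return the old buffer values at masked positions; elsewhere the result
  # just passes val through, mirroring the interpreter's swap() below.
  return np.where(mask, old, val)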

PiperOrigin-RevId: 727301519
Jacob Burnim 2025-02-15 08:35:20 -08:00 committed by jax authors
parent eaceac3bf9
commit 962eb41933


@ -15,7 +15,7 @@
import collections
from collections.abc import Iterable, Sequence
import dataclasses
from functools import reduce
import functools
import math
import threading
from typing import Any
@ -24,6 +24,7 @@ import jax
from jax import lax
from jax._src import callback
from jax._src import core as jax_core
from jax._src.lax.control_flow import for_loop
from jax._src import linear_util as lu
from jax._src.pallas.mosaic import primitives as mosaic_primitives
from jax._src.pallas.mosaic import core as mosaic_core
@ -72,13 +73,11 @@ class Semaphore:
self.counts = collections.defaultdict(int)
def signal(self, inc, device_id):
device_id = tuple(int(x) for x in device_id)
with self.cv:
self.counts[device_id] += inc
self.cv.notify_all()
def wait(self, value, device_id):
device_id = tuple(int(x) for x in device_id)
with self.cv:
while self.counts[device_id] < value:
self.cv.wait()
@ -120,7 +119,7 @@ def _clear_shared_memory():
_shared_memory = None
def _allocate_buffer(device_id, memory_space, val):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
val = np.array(val)
@ -134,7 +133,7 @@ def _allocate_buffer(device_id, memory_space, val):
return np.int16(buffer_id)
def _deallocate_buffer(device_id, memory_space, buffer_id):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
@ -144,7 +143,7 @@ def _deallocate_buffer(device_id, memory_space, buffer_id):
shared_memory.mem.pop((memory_space, buffer_id, device_id), None)
def _allocate_semaphores(device_id, shape):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
shape = tuple(map(int, shape))
num_semaphores = math.prod(shape)
@ -176,7 +175,7 @@ TPU_MEMORY_SPACE_IDXS[None] = (
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])
def get_barrier_semaphore(device_id, collective_id):
device_id = tuple(map(int, device_id))
del device_id
collective_id = int(collective_id)
# TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
@ -195,7 +194,9 @@ def _transform_slice_or_index(slice_or_idx):
return slice_or_idx
else:
start, size, stride = (
slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride)
int(slice_or_idx.start),
int(slice_or_idx.size),
int(slice_or_idx.stride))
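# E.g., start=2, size=3, stride=2 yields slice(2, 8, 2), i.e. indices
# 2, 4, and 6.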
return slice(start, start + size * stride, stride)
def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
@ -234,7 +235,7 @@ def _to_range(transforms) -> tuple[slice | int, ...]:
return ret
def get(device_id, memory_space, buffer_id, transforms):
device_id = tuple(int(x) for x in device_id)
device_id = int(device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
@ -244,12 +245,21 @@ def get(device_id, memory_space, buffer_id, transforms):
shared_memory = _get_shared_memory()
with shared_memory.lock:
return shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
].copy()
read_range = _to_range(transforms)
buffer = shared_memory.mem[(memory_space, buffer_id, device_id)]
ret = buffer[read_range].copy()
if transforms:
# TODO(jburnim): Instead of using NDIndexer, do the computation ourselves
# with buffer.shape and read_range?
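# NumPy clips out-of-range basic slices instead of raising, so an
# out-of-bounds read shows up as a result whose shape is smaller than the
# shape the indexer expected.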
expected_shape = transforms[-1].get_indexer_shape()
if expected_shape != ret.shape[:len(expected_shape)]:
raise ValueError(
f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): '
f'reading [{read_range}] but buffer has shape {buffer.shape}.')
return ret
def store(device_id, memory_space, buffer_id, transforms, val):
device_id = tuple(int(x) for x in device_id)
device_id = int(device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
@ -260,15 +270,18 @@ def store(device_id, memory_space, buffer_id, transforms, val):
shared_memory = _get_shared_memory()
with shared_memory.lock:
if transforms:
shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
] = val
else:
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
write_range = _to_range(transforms)
# TODO(jburnim): Better error message if this raises?
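# As in get() above, NumPy clips out-of-range slices, so an out-of-bounds
# write shows up as a mismatch between the clipped region's shape and
# val.shape.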
in_bounds_shape = buff[write_range].shape
if in_bounds_shape != val.shape:
raise ValueError(
f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): '
f'writing [{write_range}] but buffer has shape {buff.shape}.')
buff[write_range] = val
def swap(device_id, memory_space, buffer_id, transforms, val):
device_id = tuple(int(x) for x in device_id)
def swap(device_id, memory_space, buffer_id, transforms, val, mask):
device_id = int(device_id)
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
buffer_id = int(buffer_id)
try:
@ -276,17 +289,42 @@ def swap(device_id, memory_space, buffer_id, transforms, val):
except:
raise ValueError('Advanced indexers are not supported on TPU')
val = np.array(val)
mask = np.array(mask) if mask is not None else None
if mask is not None:
assert mask.shape == val.shape
shared_memory = _get_shared_memory()
with shared_memory.lock:
result = shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
].copy()
shared_memory.mem[(memory_space, buffer_id, device_id)][
_to_range(transforms)
] = val
buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
read_write_range = _to_range(transforms)
# TODO(jburnim): Better error message if this raises?
raw_result = buff[read_write_range]
in_bounds_shape = raw_result.shape
if mask is None:
if in_bounds_shape != val.shape:
raise ValueError(
f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): '
f'swapping [{read_write_range}] but buffer has shape {buff.shape}.')
buff[read_write_range] = val
return raw_result.copy()
return np.array(result)
in_bounds_mask = np.full(mask.shape, True)
for i, size in enumerate(in_bounds_shape):
in_bounds_mask[(slice(None),) * i + (slice(size, None),)] = False
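# E.g., for in_bounds_shape == (2, 3) and mask.shape == (4, 3), rows 2
# and 3 of in_bounds_mask are set to False: every element outside the
# clipped region is treated as out of bounds.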
if (~in_bounds_mask & mask).any():
# TODO(jburnim): Include indices of out-of-bounds locations where mask
# is True.
raise ValueError(
f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): '
f'swapping [{read_write_range}] but buffer has shape {buff.shape}.')
in_bounds_idx = tuple(slice(i) for i in in_bounds_shape)
result = val.copy()
result[in_bounds_idx] = np.where(
mask[in_bounds_idx], raw_result, val[in_bounds_idx])
buff[read_write_range] = np.where(
mask[in_bounds_idx], val[in_bounds_idx], raw_result)
return result
def execute_dma(src, dst, send_sem, recv_sem):
# NOTE: `src` is a list of arguments for `get` (device_id, memory_space,
@ -310,7 +348,7 @@ def execute_dma(src, dst, send_sem, recv_sem):
recv_sem.signal(data_size, device_id=dst[0])
def print_memory(device_id):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
if device_id == 0:
shared_memory = _get_shared_memory()
with shared_memory.lock:
@ -321,7 +359,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
dst_sem,
src_sem,
dst_device_id):
device_id = tuple(int(x) for x in device_id)
device_id = int(device_id)
src_memory_space, src_id = int(src_memory_space), int(src_id)
src_transforms = jax.tree.map(int, src_transforms)
dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
@ -330,7 +368,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
if src_sem is not None:
src_sem = int(src_sem)
if dst_device_id is not None:
dst_device_id = tuple(int(x) for x in dst_device_id)
dst_device_id = int(dst_device_id)
else:
dst_device_id = device_id
@ -349,7 +387,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
dst_sem)
def dma_wait(device_id, sem, size):
device_id = tuple(int(x) for x in device_id)
device_id = int(device_id)
sem = int(sem)
size = int(size)
@ -359,13 +397,16 @@ def dma_wait(device_id, sem, size):
sem.wait(size, device_id)
def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
sem = int(sem)
inc = int(inc)
target_device_id = tuple(map(int, target_device_id))
if target_device_id is None:
target_device_id = device_id
else:
target_device_id = int(target_device_id)
if target_core_index is not None:
raise NotImplementedError()
raise NotImplementedError('semaphore_signal with target_core_index')
shared_memory = _get_shared_memory()
with shared_memory.lock:
@ -373,7 +414,7 @@ def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
sem.signal(inc, target_device_id)
def semaphore_wait(device_id, sem, value):
device_id = tuple(map(int, device_id))
device_id = int(device_id)
sem = int(sem)
value = int(value)
@ -390,6 +431,26 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
dtype = transform.transform_dtype(dtype)
return shape, dtype
def _device_coords_to_logical_id(device_coords, axis_sizes):
if not isinstance(device_coords, tuple):
device_coords = (device_coords,)
assert len(device_coords) == len(axis_sizes)
sizes = list(axis_sizes.values())
ret = 0
for i in range(len(device_coords)):
ret += device_coords[i] * math.prod(sizes[i+1:])
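# E.g., with axis_sizes {'x': 2, 'y': 3}, device coords (1, 2) map to
# logical ID 1 * 3 + 2 == 5 (row-major order over the mesh axes).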
return ret
def _device_id_to_logical(device_id, device_id_type, axis_sizes):
if device_id is None:
return None
if device_id_type == mosaic_primitives.DeviceIdType.MESH:
return _device_coords_to_logical_id(device_id, axis_sizes)
elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL:
return device_id
else:
raise ValueError(f'Unsupported device ID type: {device_id_type}')
@lu.cache
def _to_jaxpr(flat_fun, in_avals):
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
@ -414,10 +475,11 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
# Get the mesh coordinates.
device_coords = tuple(
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
# TODO(jburnim): Convert to a single integer device ID.
# Get the device ID.
axis_sizes = jax_core.get_axis_env().axis_sizes
device_id = _device_coords_to_logical_id(
tuple(lax.axis_index(s) for s in axis_sizes.keys()),
axis_sizes)
# TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
# it for each sub-jaxpr.
@ -431,10 +493,32 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
invals = jax.util.safe_map(read, eqn.invars)
if prim is primitives.load_p:
raise NotImplementedError()
(ref, transforms, mask, _) = jax.tree.unflatten(
eqn.params['args_tree'], invals)
if mask is not None:
raise NotImplementedError('masked load_p')
out = callback.io_callback(
get,
eqn.outvars[0].aval,
device_id,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
ref,
transforms,
ordered=True)
elif prim is primitives.swap_p:
raise NotImplementedError()
(ref, transforms, val, mask) = jax.tree.unflatten(
eqn.params['args_tree'], invals)
out = callback.io_callback(
swap,
eqn.outvars[0].aval,
device_id,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
ref,
transforms,
val,
mask,
ordered=True)
elif prim is lax.cond_p:
def _make_branch(jaxpr):
@ -470,15 +554,18 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
compiler_params=compiler_params),
init_vals)
elif prim is for_loop.for_p:
raise NotImplementedError('for_p')
elif prim is pjit.pjit_p:
pjit_jaxpr = eqn.params['jaxpr']
def f(*args):
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
def f(*args, jaxpr):
return _interpret_jaxpr(jaxpr.jaxpr, *jaxpr.consts, *args,
compiler_params=compiler_params)
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
new_jaxpr = _to_jaxpr(lu.wrap_init(f,
debug_info=pjit_jaxpr.jaxpr.debug_info),
in_avals)
new_jaxpr = _to_jaxpr(
lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']),
debug_info=eqn.params['jaxpr'].jaxpr.debug_info),
in_avals)
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
elif prim is primitives.run_scoped_p:
@ -490,14 +577,14 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
allocs.append(callback.io_callback(
_allocate_semaphores,
jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
device_coords,
device_id,
v.aval.shape,
ordered=True))
else:
allocs.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
ordered=True))
@ -510,7 +597,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
callback.io_callback(
_deallocate_buffer,
None,
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
a,
ordered=True)
@ -519,7 +606,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
# callback.io_callback(
# _deallocate_semaphores,
# None,
# device_coords,
# device_id,
# a,
# ordered=True)
pass
@ -528,7 +615,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
out = callback.io_callback(
get,
eqn.outvars[0].aval,
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
invals[0],
jax.tree.unflatten(eqn.params['tree'], invals[1:]),
@ -538,11 +625,12 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
out = callback.io_callback(
swap,
eqn.outvars[0].aval,
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
invals[0],
jax.tree.unflatten(eqn.params['tree'], invals[2:]),
invals[1],
None,
ordered=True)
elif prim is mosaic_primitives.dma_start_p:
@ -550,20 +638,22 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
dst, dst_transforms,
dst_sem, dst_sem_transforms,
src_sem, src_sem_transforms,
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
target_device_id = _device_id_to_logical(
target_device_id, eqn.params['device_id_type'], axis_sizes)
(orig_src_ref, _, orig_dst_ref, *_
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
callback.io_callback(
dma_start,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space],
src, src_transforms,
TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space],
dst, dst_transforms,
state_discharge.transform_array(dst_sem, dst_sem_transforms),
state_discharge.transform_array(src_sem, src_sem_transforms),
device_id,
target_device_id,
ordered=True)
out = []
@ -572,13 +662,13 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
dst, dst_transforms,
dst_sem, dst_sem_transforms,
src_sem, src_sem_transforms,
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
read_shape, read_dtype = _compute_transformed_shape_and_dtype(
eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
callback.io_callback(
dma_wait,
(),
device_coords,
device_id,
state_discharge.transform_array(dst_sem, dst_sem_transforms),
math.prod(read_shape) * read_dtype.itemsize,
ordered=True)
@ -588,20 +678,22 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
out = callback.io_callback(
get_barrier_semaphore,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
device_id,
compiler_params['mosaic']['collective_id'],
ordered=True)
elif prim is mosaic_primitives.semaphore_signal_p:
sem, sem_transforms, inc, device_id, core_index = (
sem, sem_transforms, inc, target_device_id, core_index = (
jax.tree.unflatten(eqn.params['args_tree'], invals))
target_device_id = _device_id_to_logical(
target_device_id, eqn.params['device_id_type'], axis_sizes)
callback.io_callback(
semaphore_signal,
(),
device_coords,
device_id,
state_discharge.transform_array(sem, sem_transforms),
inc,
device_id,
target_device_id,
core_index,
ordered=True)
out = []
@ -612,14 +704,21 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
callback.io_callback(
semaphore_wait,
(),
device_coords,
device_id,
state_discharge.transform_array(sem, sem_transforms),
value,
ordered=True)
out = []
elif prim is primitives.atomic_rmw_p:
raise NotImplementedError('atomic_rmw_p')
elif prim is primitives.atomic_cas_p:
raise NotImplementedError('atomic_cas_p')
else:
out = prim.bind(*invals, **eqn.params)
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
out = prim.bind(*subfuns, *invals, **bind_params)
out = out if prim.multiple_results else [out]
jax.util.safe_map(write, eqn.outvars, out)
@ -668,6 +767,29 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
dtype=np.bool_)])
return lax.squeeze(output, squeeze_dims)
def _pad_to_block_dimension(value, block_shape):
"""Pads values so the shape evenly divides into block dimensions.
For example, if values has a shape of (33, 2, 5) with a block_shape of
(32, 2, 4), this function will pad the value of shape to (64, 2, 8).
Args:
value: Array to be padded.
block_shape: Block shapes to use for padding. If None, no padding will
be performed.
Returns:
A padded array.
"""
padded_shape = tuple(
((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
)
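# Note: (v - 1) // b + 1 is ceil(v / b), so each dimension is rounded up
# to a whole number of blocks before being scaled back by the block size.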
if padded_shape != value.shape:
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
value = jnp.pad(value, pad_width, constant_values=pad_value)
return value
def get_interpret_effects():
return {callback._OrderedIOEffect}
@ -692,66 +814,93 @@ def interpret_pallas_call(
# TODO(jburnim): Support dynamic grid sizes?
grid = grid_mapping.static_grid
device_coords = tuple(
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
axis_sizes = jax_core.get_axis_env().axis_sizes
device_id = _device_coords_to_logical_id(
tuple(lax.axis_index(s) for s in axis_sizes.keys()),
axis_sizes)
# Pad input arguments.
is_indexing_dim = [
tuple(b is pallas_core.mapped for b in bm.block_shape)
for bm in grid_mapping.block_mappings
]
block_shapes = [
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
]
num_inputs = grid_mapping.num_inputs
input_args = [
_pad_to_block_dimension(a, bs)
for a, bs in zip(input_args, block_shapes[:num_inputs])
]
# Allocate buffers in HBM for outputs.
io_alias_map = dict(input_output_aliases)
output_buffer_ids = []
output_buffer_shapes = []
output_vals = _initialize_output_vals(
grid_mapping.block_mappings_output, args, input_output_aliases)
for out_val in output_vals:
num_outputs = grid_mapping.num_outputs
output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
for out_val, bs in zip(output_vals, output_block_shapes):
padded_val = _pad_to_block_dimension(out_val, bs)
output_buffer_shapes.append(padded_val.shape)
output_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
out_val,
padded_val,
ordered=True))
# Allocate buffers for all kernel arguments (e.g., scalars, inputs,
# outputs, scratch).
io_alias_map = dict(input_output_aliases)
oi_alias_map = {v: k for k, v in input_output_aliases}
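# E.g., input_output_aliases == ((2, 0),) aliases pallas_call input 2 with
# output 0, so io_alias_map == {2: 0} and oi_alias_map == {0: 2}.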
kernel_buffer_ids = []
for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
kernel_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
val,
ordered=True))
for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
output_idx = i - grid_mapping.num_inputs
is_input = i < grid_mapping.num_inputs
is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
kernel_buffer_ids.append(callback.io_callback(
_allocate_semaphores,
jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
device_coords,
device_id,
var.aval.shape,
ordered=True))
elif is_output and _is_any(var.aval.memory_space):
# Don't allocate a buffer -- use the already-allocated HBM output buffer.
# Use the already-allocated HBM output buffer.
#
# TODO(jburnim): For kernel args in HBM, check that block shape is the
# same as for the corresponding pallas_call input, and that the index_map
# is trivial.
kernel_buffer_ids.append(output_buffer_ids[output_idx])
elif (i < grid_mapping.num_inputs) and (i in io_alias_map):
# Instead of allocating a new buffer, use the already-allocated
# HBM output buffer.
assert _is_any(var.aval.memory_space)
elif is_output and (output_idx in oi_alias_map):
# Use the already-allocated (non-HBM) input buffer.
kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]])
elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space):
# Use the already-allocated HBM output buffer.
kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
else:
# TODO(jburnim): For kernel args in HBM, check that block shape is the
# same as for the corresponding pallas_call input, and that the index_map
# is trivial.
kernel_buffer_ids.append(callback.io_callback(
_allocate_buffer,
jax.ShapeDtypeStruct((), jnp.int16),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
ordered=True))
num_inputs = grid_mapping.num_inputs
_, input_ids, kernel_output_ids, _ = split_list(
kernel_buffer_ids,
[grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
@ -769,27 +918,19 @@ def interpret_pallas_call(
callback.io_callback(
store,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
buffer_id,
(),
val,
ordered=True)
is_indexing_dim = [
tuple(b is pallas_core.mapped for b in bm.block_shape)
for bm in grid_mapping.block_mappings
]
block_shapes = [
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
]
scalar_ids, in_out_ids, scratch_ids = split_list(
kernel_buffer_ids,
[grid_mapping.num_index_operands, len(grid_mapping.block_mappings)])
if grid:
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type]
else:
# Base case is always one iteration when grid is ()
num_iterations = 1
@ -824,7 +965,7 @@ def interpret_pallas_call(
callback.io_callback(
store,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
input_ids[j],
(),
@ -845,21 +986,22 @@ def interpret_pallas_call(
kernel_output_val = callback.io_callback(
get,
var.aval,
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
kernel_output_ids[j],
(),
ordered=True)
transform = indexing.NDIndexer(
indices=tuple(indexing.ds(st, sz)
for st, sz in zip(start_indices[num_inputs + j],
block_shapes[num_inputs + j])),
indices=tuple(indexing.ds(st, sz) if not iid else st
for st, sz, iid in zip(start_indices[num_inputs + j],
block_shapes[num_inputs + j],
is_indexing_dim[num_inputs + j])),
shape=output_vals[j].shape,
int_indexer_shape=())
callback.io_callback(
store,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
output_buffer_ids[j],
(transform,),
@ -880,19 +1022,22 @@ def interpret_pallas_call(
callback.io_callback(
get,
val,
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
output_buffer_id,
(),
(indexing.NDIndexer.from_indices_shape(
tuple(indexing.ds(0, s) for s in val.shape),
output_buffer_shape),),
ordered=True)
for val, output_buffer_id in zip(output_vals, output_buffer_ids)
for val, output_buffer_id, output_buffer_shape in zip(
output_vals, output_buffer_ids, output_buffer_shapes)
]
for buffer_id in output_buffer_ids:
callback.io_callback(
_deallocate_buffer,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
buffer_id,
ordered=True)
@ -903,7 +1048,7 @@ def interpret_pallas_call(
callback.io_callback(
_deallocate_buffer,
(),
device_coords,
device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
buffer_id,
ordered=True)