mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Several fixes/improvements for the new TPU interpret mode.
- Checks bounds for reads and writes to shared memory.
- Pads kernel arguments when necessary.
- Fix support for input-output aliasing.
- Fix handling of vmap'ed dimensions.
- Supports un-masked `pl.load` and masked or un-masked `pl.swap`.
- Switch to using single integer device IDs instead of tuples.
- Better error messages for unsupported primitives: `for_p`, `atomic_rmw_p`, and `atomic_cas_p`.

PiperOrigin-RevId: 727301519
This commit is contained in:
parent
eaceac3bf9
commit
962eb41933
@ -15,7 +15,7 @@
|
||||
import collections
|
||||
from collections.abc import Iterable, Sequence
|
||||
import dataclasses
|
||||
from functools import reduce
|
||||
import functools
|
||||
import math
|
||||
import threading
|
||||
from typing import Any
|
||||
@ -24,6 +24,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax._src import callback
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.pallas.mosaic import primitives as mosaic_primitives
|
||||
from jax._src.pallas.mosaic import core as mosaic_core
|
||||
@ -72,13 +73,11 @@ class Semaphore:
|
||||
self.counts = collections.defaultdict(int)
|
||||
|
||||
def signal(self, inc, device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
with self.cv:
|
||||
self.counts[device_id] += inc
|
||||
self.cv.notify_all()
|
||||
|
||||
def wait(self, value, device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
with self.cv:
|
||||
while self.counts[device_id] < value:
|
||||
self.cv.wait()
|
||||
@ -120,7 +119,7 @@ def _clear_shared_memory():
|
||||
_shared_memory = None
|
||||
|
||||
def _allocate_buffer(device_id, memory_space, val):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
val = np.array(val)
|
||||
|
||||
@ -134,7 +133,7 @@ def _allocate_buffer(device_id, memory_space, val):
|
||||
return np.int16(buffer_id)
|
||||
|
||||
def _deallocate_buffer(device_id, memory_space, buffer_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
|
||||
@ -144,7 +143,7 @@ def _deallocate_buffer(device_id, memory_space, buffer_id):
|
||||
shared_memory.mem.pop((memory_space, buffer_id, device_id), None)
|
||||
|
||||
def _allocate_semaphores(device_id, shape):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
shape = tuple(map(int, shape))
|
||||
num_semaphores = math.prod(shape)
|
||||
|
||||
@ -176,7 +175,7 @@ TPU_MEMORY_SPACE_IDXS[None] = (
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])
|
||||
|
||||
def get_barrier_semaphore(device_id, collective_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
del device_id
|
||||
collective_id = int(collective_id)
|
||||
|
||||
# TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
|
||||
@ -195,7 +194,9 @@ def _transform_slice_or_index(slice_or_idx):
|
||||
return slice_or_idx
|
||||
else:
|
||||
start, size, stride = (
|
||||
slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride)
|
||||
int(slice_or_idx.start),
|
||||
int(slice_or_idx.size),
|
||||
int(slice_or_idx.stride))
|
||||
return slice(start, start + size * stride, stride)
|
||||
|
||||
def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
|
||||
@ -234,7 +235,7 @@ def _to_range(transforms) -> tuple[slice | int, ...]:
|
||||
return ret
|
||||
|
||||
def get(device_id, memory_space, buffer_id, transforms):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
device_id = int(device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
@ -244,12 +245,21 @@ def get(device_id, memory_space, buffer_id, transforms):
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
return shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
].copy()
|
||||
read_range = _to_range(transforms)
|
||||
buffer = shared_memory.mem[(memory_space, buffer_id, device_id)]
|
||||
ret = buffer[read_range].copy()
|
||||
if transforms:
|
||||
# TODO(jburnim): Instead of using NDIndexer, do the computation ourselves
|
||||
# with buffer.shape and read_range?
|
||||
expected_shape = transforms[-1].get_indexer_shape()
|
||||
if expected_shape != ret.shape[:len(expected_shape)]:
|
||||
raise ValueError(
|
||||
f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): '
|
||||
f'reading [{read_range}] but bufer has shape {buffer.shape} .')
|
||||
return ret
|
||||
|
||||
def store(device_id, memory_space, buffer_id, transforms, val):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
device_id = int(device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
@ -260,15 +270,18 @@ def store(device_id, memory_space, buffer_id, transforms, val):
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
if transforms:
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
] = val
|
||||
else:
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)] = val
|
||||
buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
|
||||
write_range = _to_range(transforms)
|
||||
# TODO(jburnim): Better error message if this raises?
|
||||
in_bounds_shape = buff[write_range].shape
|
||||
if in_bounds_shape != val.shape:
|
||||
raise ValueError(
|
||||
f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): '
|
||||
f'writing [{write_range}] but buffer has shape {buff.shape} .')
|
||||
buff[write_range] = val
|
||||
|
||||
def swap(device_id, memory_space, buffer_id, transforms, val):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
def swap(device_id, memory_space, buffer_id, transforms, val, mask):
|
||||
device_id = int(device_id)
|
||||
memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
|
||||
buffer_id = int(buffer_id)
|
||||
try:
|
||||
@ -276,17 +289,42 @@ def swap(device_id, memory_space, buffer_id, transforms, val):
|
||||
except:
|
||||
raise ValueError('Advanced indexers are not supported on TPU')
|
||||
val = np.array(val)
|
||||
mask = np.array(mask) if mask is not None else None
|
||||
if mask is not None:
|
||||
assert mask.shape == val.shape
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
result = shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
].copy()
|
||||
shared_memory.mem[(memory_space, buffer_id, device_id)][
|
||||
_to_range(transforms)
|
||||
] = val
|
||||
buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
|
||||
read_write_range = _to_range(transforms)
|
||||
# TODO(jburnim): Better error message if this raises?
|
||||
raw_result = buff[read_write_range]
|
||||
in_bounds_shape = raw_result.shape
|
||||
if mask is None:
|
||||
if in_bounds_shape != val.shape:
|
||||
raise ValueError(
|
||||
f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): '
|
||||
f'swapping [{read_write_range}] but buffer has shape {buff.shape} .')
|
||||
buff[read_write_range] = val
|
||||
return raw_result.copy()
|
||||
|
||||
return np.array(result)
|
||||
in_bounds_mask = np.full(mask.shape, True)
|
||||
for i in range(len(in_bounds_shape)):
|
||||
in_bounds_mask[in_bounds_shape[i]:] = False
|
||||
if (~in_bounds_mask & mask).any():
|
||||
# TODO(jburnim): Include indices of out-of-bounds locations where mask
|
||||
# is True.
|
||||
raise ValueError(
|
||||
f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): '
|
||||
f'swapping [{read_write_range}] but buffer has shape {buff.shape} . ')
|
||||
|
||||
in_bounds_idx = tuple(slice(i) for i in in_bounds_shape)
|
||||
result = val.copy()
|
||||
result[in_bounds_idx] = np.where(
|
||||
mask[in_bounds_idx], raw_result, val[in_bounds_idx])
|
||||
buff[read_write_range] = np.where(
|
||||
mask[in_bounds_idx], val[in_bounds_idx], raw_result)
|
||||
return result
|
||||
|
||||
def execute_dma(src, dst, send_sem, recv_sem):
|
||||
# NOTE: `src` is a list of arguments for `get` (device_id, memory_space,
|
||||
@ -310,7 +348,7 @@ def execute_dma(src, dst, send_sem, recv_sem):
|
||||
recv_sem.signal(data_size, device_id=dst[0])
|
||||
|
||||
def print_memory(device_id):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
if all(d == 0 for d in device_id):
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
@ -321,7 +359,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
|
||||
dst_sem,
|
||||
src_sem,
|
||||
dst_device_id):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
device_id = int(device_id)
|
||||
src_memory_space, src_id = int(src_memory_space), int(src_id)
|
||||
src_transforms = jax.tree.map(int, src_transforms)
|
||||
dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
|
||||
@ -330,7 +368,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
|
||||
if src_sem is not None:
|
||||
src_sem = int(src_sem)
|
||||
if dst_device_id is not None:
|
||||
dst_device_id = tuple(int(x) for x in dst_device_id)
|
||||
dst_device_id = int(dst_device_id)
|
||||
else:
|
||||
dst_device_id = device_id
|
||||
|
||||
@ -349,7 +387,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
|
||||
dst_sem)
|
||||
|
||||
def dma_wait(device_id, sem, size):
|
||||
device_id = tuple(int(x) for x in device_id)
|
||||
device_id = int(device_id)
|
||||
sem = int(sem)
|
||||
size = int(size)
|
||||
|
||||
@ -359,13 +397,16 @@ def dma_wait(device_id, sem, size):
|
||||
sem.wait(size, device_id)
|
||||
|
||||
def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
sem = int(sem)
|
||||
inc = int(inc)
|
||||
target_device_id = tuple(map(int, target_device_id))
|
||||
if target_device_id is None:
|
||||
target_device_id = device_id
|
||||
else:
|
||||
target_device_id = int(target_device_id)
|
||||
|
||||
if target_core_index is not None:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError('semaphore_signal with target_core_index')
|
||||
|
||||
shared_memory = _get_shared_memory()
|
||||
with shared_memory.lock:
|
||||
@ -373,7 +414,7 @@ def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
|
||||
sem.signal(inc, target_device_id)
|
||||
|
||||
def semaphore_wait(device_id, sem, value):
|
||||
device_id = tuple(map(int, device_id))
|
||||
device_id = int(device_id)
|
||||
sem = int(sem)
|
||||
value = int(value)
|
||||
|
||||
@ -390,6 +431,26 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
|
||||
dtype = transform.transform_dtype(dtype)
|
||||
return shape, dtype
|
||||
|
||||
def _device_coords_to_logical_id(device_coords, axis_sizes):
  """Flattens per-axis device coordinates into a single integer ID.

  Each coordinate is weighted by the product of the sizes of all axes that
  come after it, so the last axis varies fastest.

  Args:
    device_coords: Tuple with one coordinate per entry of `axis_sizes`. A
      non-tuple value is treated as a single coordinate.
    axis_sizes: Mapping whose values are the axis sizes, in axis order.

  Returns:
    The sum over axes of `coordinate * product(sizes of later axes)`.
  """
  coords = (
      device_coords if isinstance(device_coords, tuple) else (device_coords,))
  assert len(coords) == len(axis_sizes)
  dims = list(axis_sizes.values())
  # The stride of an axis is the product of the sizes of every axis after it;
  # for the last axis the slice is empty and math.prod returns 1.
  return sum(
      coord * math.prod(dims[axis + 1:]) for axis, coord in enumerate(coords))
|
||||
|
||||
def _device_id_to_logical(device_id, device_id_type, axis_sizes):
  """Normalizes a device ID of the given type to a logical device ID.

  Args:
    device_id: The device ID to convert, or None.
    device_id_type: A `mosaic_primitives.DeviceIdType` value saying how
      `device_id` is expressed.
    axis_sizes: Mapping of axis sizes, forwarded to
      `_device_coords_to_logical_id` for MESH IDs.

  Returns:
    None if `device_id` is None; the flattened ID for MESH IDs; `device_id`
    unchanged for LOGICAL IDs.

  Raises:
    ValueError: If `device_id_type` is neither MESH nor LOGICAL.
  """
  # A missing device ID is passed through untouched, before the ID type is
  # even inspected.
  if device_id is None:
    return None

  id_types = mosaic_primitives.DeviceIdType
  if device_id_type == id_types.MESH:
    return _device_coords_to_logical_id(device_id, axis_sizes)
  if device_id_type == id_types.LOGICAL:
    return device_id
  raise ValueError(f'Unsupported device ID type: {device_id_type}')
|
||||
|
||||
@lu.cache
|
||||
def _to_jaxpr(flat_fun, in_avals):
|
||||
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
@ -414,10 +475,11 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
|
||||
jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
|
||||
|
||||
# Get the mesh coordinates.
|
||||
device_coords = tuple(
|
||||
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
|
||||
# TODO(jburnim): Convert to a single integer device ID.
|
||||
# Get the device ID.
|
||||
axis_sizes = jax_core.get_axis_env().axis_sizes
|
||||
device_id = _device_coords_to_logical_id(
|
||||
tuple(lax.axis_index(s) for s in axis_sizes.keys()),
|
||||
axis_sizes)
|
||||
# TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
|
||||
# it for each sub-jaxpr.
|
||||
|
||||
@ -431,10 +493,32 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
invals = jax.util.safe_map(read, eqn.invars)
|
||||
|
||||
if prim is primitives.load_p:
|
||||
raise NotImplementedError()
|
||||
(ref, transforms, mask, _) = jax.tree.unflatten(
|
||||
eqn.params['args_tree'], invals)
|
||||
if mask is not None:
|
||||
raise NotImplementedError('masked load_p')
|
||||
out = callback.io_callback(
|
||||
get,
|
||||
eqn.outvars[0].aval,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
ref,
|
||||
transforms,
|
||||
ordered=True)
|
||||
|
||||
elif prim is primitives.swap_p:
|
||||
raise NotImplementedError()
|
||||
(ref, transforms, val, mask) = jax.tree.unflatten(
|
||||
eqn.params['args_tree'], invals)
|
||||
out = callback.io_callback(
|
||||
swap,
|
||||
eqn.outvars[0].aval,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
ref,
|
||||
transforms,
|
||||
val,
|
||||
mask,
|
||||
ordered=True)
|
||||
|
||||
elif prim is lax.cond_p:
|
||||
def _make_branch(jaxpr):
|
||||
@ -470,15 +554,18 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
compiler_params=compiler_params),
|
||||
init_vals)
|
||||
|
||||
elif prim is for_loop.for_p:
|
||||
raise NotImplementedError('for_p')
|
||||
|
||||
elif prim is pjit.pjit_p:
|
||||
pjit_jaxpr = eqn.params['jaxpr']
|
||||
def f(*args):
|
||||
return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
|
||||
def f(*args, jaxpr):
|
||||
return _interpret_jaxpr(jaxpr.jaxpr, *jaxpr.consts, *args,
|
||||
compiler_params=compiler_params)
|
||||
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
|
||||
new_jaxpr = _to_jaxpr(lu.wrap_init(f,
|
||||
debug_info=pjit_jaxpr.jaxpr.debug_info),
|
||||
in_avals)
|
||||
new_jaxpr = _to_jaxpr(
|
||||
lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']),
|
||||
debug_info=eqn.params['jaxpr'].jaxpr.debug_info),
|
||||
in_avals)
|
||||
out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
|
||||
|
||||
elif prim is primitives.run_scoped_p:
|
||||
@ -490,14 +577,14 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
allocs.append(callback.io_callback(
|
||||
_allocate_semaphores,
|
||||
jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
v.aval.shape,
|
||||
ordered=True))
|
||||
else:
|
||||
allocs.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
|
||||
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
|
||||
ordered=True))
|
||||
@ -510,7 +597,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
None,
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
|
||||
a,
|
||||
ordered=True)
|
||||
@ -519,7 +606,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
# callback.io_callback(
|
||||
# _deallocate_semaphores,
|
||||
# None,
|
||||
# device_coords,
|
||||
# device_id,
|
||||
# a,
|
||||
# ordered=True)
|
||||
pass
|
||||
@ -528,7 +615,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
out = callback.io_callback(
|
||||
get,
|
||||
eqn.outvars[0].aval,
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
invals[0],
|
||||
jax.tree.unflatten(eqn.params['tree'], invals[1:]),
|
||||
@ -538,11 +625,12 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
out = callback.io_callback(
|
||||
swap,
|
||||
eqn.outvars[0].aval,
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
|
||||
invals[0],
|
||||
jax.tree.unflatten(eqn.params['tree'], invals[2:]),
|
||||
invals[1],
|
||||
None,
|
||||
ordered=True)
|
||||
|
||||
elif prim is mosaic_primitives.dma_start_p:
|
||||
@ -550,20 +638,22 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
dst, dst_transforms,
|
||||
dst_sem, dst_sem_transforms,
|
||||
src_sem, src_sem_transforms,
|
||||
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
target_device_id = _device_id_to_logical(
|
||||
target_device_id, eqn.params['device_id_type'], axis_sizes)
|
||||
(orig_src_ref, _, orig_dst_ref, *_
|
||||
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
|
||||
callback.io_callback(
|
||||
dma_start,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space],
|
||||
src, src_transforms,
|
||||
TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space],
|
||||
dst, dst_transforms,
|
||||
state_discharge.transform_array(dst_sem, dst_sem_transforms),
|
||||
state_discharge.transform_array(src_sem, src_sem_transforms),
|
||||
device_id,
|
||||
target_device_id,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
@ -572,13 +662,13 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
dst, dst_transforms,
|
||||
dst_sem, dst_sem_transforms,
|
||||
src_sem, src_sem_transforms,
|
||||
device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
|
||||
read_shape, read_dtype = _compute_transformed_shape_and_dtype(
|
||||
eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
|
||||
callback.io_callback(
|
||||
dma_wait,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
state_discharge.transform_array(dst_sem, dst_sem_transforms),
|
||||
math.prod(read_shape) * read_dtype.itemsize,
|
||||
ordered=True)
|
||||
@ -588,20 +678,22 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
out = callback.io_callback(
|
||||
get_barrier_semaphore,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
compiler_params['mosaic']['collective_id'],
|
||||
ordered=True)
|
||||
|
||||
elif prim is mosaic_primitives.semaphore_signal_p:
|
||||
sem, sem_transforms, inc, device_id, core_index = (
|
||||
sem, sem_transforms, inc, target_device_id, core_index = (
|
||||
jax.tree.unflatten(eqn.params['args_tree'], invals))
|
||||
target_device_id = _device_id_to_logical(
|
||||
target_device_id, eqn.params['device_id_type'], axis_sizes)
|
||||
callback.io_callback(
|
||||
semaphore_signal,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
state_discharge.transform_array(sem, sem_transforms),
|
||||
inc,
|
||||
device_id,
|
||||
target_device_id,
|
||||
core_index,
|
||||
ordered=True)
|
||||
out = []
|
||||
@ -612,14 +704,21 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params):
|
||||
callback.io_callback(
|
||||
semaphore_wait,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
state_discharge.transform_array(sem, sem_transforms),
|
||||
value,
|
||||
ordered=True)
|
||||
out = []
|
||||
|
||||
elif prim is primitives.atomic_rmw_p:
|
||||
raise NotImplementedError('atomic_rmw_p')
|
||||
|
||||
elif prim is primitives.atomic_cas_p:
|
||||
raise NotImplementedError('atomic_cas_p')
|
||||
|
||||
else:
|
||||
out = prim.bind(*invals, **eqn.params)
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
out = prim.bind(*subfuns, *invals, **bind_params)
|
||||
|
||||
out = out if prim.multiple_results else [out]
|
||||
jax.util.safe_map(write, eqn.outvars, out)
|
||||
@ -668,6 +767,29 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
|
||||
dtype=np.bool_)])
|
||||
return lax.squeeze(output, squeeze_dims)
|
||||
|
||||
def _pad_to_block_dimension(value, block_shape):
  """Pads `value` so each dimension is a whole multiple of its block size.

  For example, if `value` has a shape of (33, 2, 5) with a block_shape of
  (32, 2, 4), this function will pad the value to shape (64, 2, 8).

  Args:
    value: Array to be padded.
    block_shape: Per-dimension block sizes to use for padding. If None, no
      padding is performed and `value` is returned unchanged.

  Returns:
    `value` itself when no padding is needed. Otherwise a padded array, with
    the padding appended at the end of each dimension and filled with
    `primitives.uninitialized_value` for `value`'s dtype.
  """
  # Honor the documented contract: without a block shape there is nothing to
  # align to (previously this fell through to zip() and raised TypeError).
  if block_shape is None:
    return value

  # Round every dimension up to the next multiple of its block size.
  padded_shape = tuple(
      ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
  )
  if padded_shape != value.shape:
    # Pad only on the high side of each dimension so existing indices keep
    # their meaning.
    pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
    value = jnp.pad(value, pad_width, constant_values=pad_value)
  return value
|
||||
|
||||
def get_interpret_effects():
|
||||
return {callback._OrderedIOEffect}
|
||||
|
||||
@ -692,66 +814,93 @@ def interpret_pallas_call(
|
||||
# TODO(jburnim): Support dynamic grid sizes?
|
||||
grid = grid_mapping.static_grid
|
||||
|
||||
device_coords = tuple(
|
||||
lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
|
||||
axis_sizes = jax_core.get_axis_env().axis_sizes
|
||||
device_id = _device_coords_to_logical_id(
|
||||
tuple(lax.axis_index(s) for s in axis_sizes.keys()),
|
||||
axis_sizes)
|
||||
|
||||
# Pad input arguments.
|
||||
is_indexing_dim = [
|
||||
tuple(b is pallas_core.mapped for b in bm.block_shape)
|
||||
for bm in grid_mapping.block_mappings
|
||||
]
|
||||
block_shapes = [
|
||||
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
|
||||
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
|
||||
]
|
||||
num_inputs = grid_mapping.num_inputs
|
||||
input_args = [
|
||||
_pad_to_block_dimension(a, bs)
|
||||
for a, bs in zip(input_args, block_shapes[:num_inputs])
|
||||
]
|
||||
|
||||
# Allocate buffers in HBM for outputs.
|
||||
io_alias_map = dict(input_output_aliases)
|
||||
output_buffer_ids = []
|
||||
output_buffer_shapes = []
|
||||
output_vals = _initialize_output_vals(
|
||||
grid_mapping.block_mappings_output, args, input_output_aliases)
|
||||
for out_val in output_vals:
|
||||
num_outputs = grid_mapping.num_outputs
|
||||
output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
|
||||
for out_val, bs in zip(output_vals, output_block_shapes):
|
||||
padded_val = _pad_to_block_dimension(out_val, bs)
|
||||
output_buffer_shapes.append(padded_val.shape)
|
||||
output_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
out_val,
|
||||
padded_val,
|
||||
ordered=True))
|
||||
|
||||
# Allocate buffers for all kernel arguments (e.g., scalars, inputs,
|
||||
# outputs, scratch).
|
||||
io_alias_map = dict(input_output_aliases)
|
||||
oi_alias_map = {v: k for k, v in input_output_aliases}
|
||||
kernel_buffer_ids = []
|
||||
for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
|
||||
val,
|
||||
ordered=True))
|
||||
for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
|
||||
output_idx = i - grid_mapping.num_inputs
|
||||
is_input = i < grid_mapping.num_inputs
|
||||
is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
|
||||
if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_semaphores,
|
||||
jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
var.aval.shape,
|
||||
ordered=True))
|
||||
elif is_output and _is_any(var.aval.memory_space):
|
||||
# Don't allocate a buffer -- use the already-allocated HBM output buffer.
|
||||
# Use the already-allocated HBM output buffer.
|
||||
#
|
||||
# TODO(jburnim): For kernel args in HBM, check that block shape is the
|
||||
# same as for the corresponding pallas_call input, and that the index_map
|
||||
# is trivial.
|
||||
kernel_buffer_ids.append(output_buffer_ids[output_idx])
|
||||
elif (i < grid_mapping.num_inputs) and (i in io_alias_map):
|
||||
# Instead of allocating a new buffer, use the already-allocated
|
||||
# HBM output buffer.
|
||||
assert _is_any(var.aval.memory_space)
|
||||
elif is_output and (output_idx in oi_alias_map):
|
||||
# Use the already-allocated (non-HBM) input buffer.
|
||||
kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]])
|
||||
elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space):
|
||||
# Use the already-allocated HBM output buffer.
|
||||
kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
|
||||
else:
|
||||
# TODO(jburnim): For kernel args in HBM, check that block shape is the
|
||||
# same as for the corresponding pallas_call input, and that the index_map
|
||||
# is trivial.
|
||||
kernel_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
|
||||
ordered=True))
|
||||
|
||||
num_inputs = grid_mapping.num_inputs
|
||||
_, input_ids, kernel_output_ids, _ = split_list(
|
||||
kernel_buffer_ids,
|
||||
[grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
|
||||
@ -769,27 +918,19 @@ def interpret_pallas_call(
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
buffer_id,
|
||||
(),
|
||||
val,
|
||||
ordered=True)
|
||||
|
||||
is_indexing_dim = [
|
||||
tuple(b is pallas_core.mapped for b in bm.block_shape)
|
||||
for bm in grid_mapping.block_mappings
|
||||
]
|
||||
block_shapes = [
|
||||
tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
|
||||
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
|
||||
]
|
||||
scalar_ids, in_out_ids, scratch_ids = split_list(
|
||||
kernel_buffer_ids,
|
||||
[grid_mapping.num_index_operands, len(grid_mapping.block_mappings)])
|
||||
|
||||
if grid:
|
||||
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
|
||||
num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type]
|
||||
else:
|
||||
# Base case is always one iteration when grid is ()
|
||||
num_iterations = 1
|
||||
@ -824,7 +965,7 @@ def interpret_pallas_call(
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
input_ids[j],
|
||||
(),
|
||||
@ -845,21 +986,22 @@ def interpret_pallas_call(
|
||||
kernel_output_val = callback.io_callback(
|
||||
get,
|
||||
var.aval,
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
kernel_output_ids[j],
|
||||
(),
|
||||
ordered=True)
|
||||
transform = indexing.NDIndexer(
|
||||
indices=tuple(indexing.ds(st, sz)
|
||||
for st, sz in zip(start_indices[num_inputs + j],
|
||||
block_shapes[num_inputs + j])),
|
||||
indices=tuple(indexing.ds(st, sz) if not iid else st
|
||||
for st, sz, iid in zip(start_indices[num_inputs + j],
|
||||
block_shapes[num_inputs + j],
|
||||
is_indexing_dim[num_inputs + j])),
|
||||
shape=output_vals[j].shape,
|
||||
int_indexer_shape=())
|
||||
callback.io_callback(
|
||||
store,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
output_buffer_ids[j],
|
||||
(transform,),
|
||||
@ -880,19 +1022,22 @@ def interpret_pallas_call(
|
||||
callback.io_callback(
|
||||
get,
|
||||
val,
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
output_buffer_id,
|
||||
(),
|
||||
(indexing.NDIndexer.from_indices_shape(
|
||||
tuple(indexing.ds(0, s) for s in val.shape),
|
||||
output_buffer_shape),),
|
||||
ordered=True)
|
||||
for val, output_buffer_id in zip(output_vals, output_buffer_ids)
|
||||
for val, output_buffer_id, output_buffer_shape in zip(
|
||||
output_vals, output_buffer_ids, output_buffer_shapes)
|
||||
]
|
||||
|
||||
for buffer_id in output_buffer_ids:
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
|
||||
buffer_id,
|
||||
ordered=True)
|
||||
@ -903,7 +1048,7 @@ def interpret_pallas_call(
|
||||
callback.io_callback(
|
||||
_deallocate_buffer,
|
||||
(),
|
||||
device_coords,
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
buffer_id,
|
||||
ordered=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user