diff --git a/jax/BUILD b/jax/BUILD index 28d32d834..6eda8311a 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -652,6 +652,7 @@ pytype_strict_library( "//jax/_src/pallas", "//jax/_src/pallas/mosaic:core", "//jax/_src/pallas/mosaic:helpers", + "//jax/_src/pallas/mosaic:interpret", "//jax/_src/pallas/mosaic:lowering", "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/mosaic:pipeline", diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 90da668c3..d239fba98 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -148,3 +148,15 @@ py_library( "//jax/_src/pallas", ], ) + +py_library( + name = "interpret", + srcs = ["interpret.py"], + deps = [ + ":core", + ":primitives", + "//jax", + "//jax/_src/lib", + "//jax/_src/pallas", + ] + py_deps("numpy"), +) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py new file mode 100644 index 000000000..aec9326a4 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret.py @@ -0,0 +1,909 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from collections.abc import Iterable, Sequence +import dataclasses +from functools import reduce +import math +import threading +from typing import Any + +import jax +from jax import lax +from jax._src import callback +from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives +from jax._src import pjit +from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives +from jax._src.util import ( + safe_map, + safe_zip, + split_list +) +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np + + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Grid = pallas_core.Grid +TupleGrid = pallas_core.TupleGrid +GridSpec = pallas_core.GridSpec +BlockMapping = pallas_core.BlockMapping +GridMapping = pallas_core.GridMapping +BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree +NoBlockSpec = pallas_core.NoBlockSpec +no_block_spec = pallas_core.no_block_spec +ScratchShapeTree = pallas_core.ScratchShapeTree +CostEstimate = pallas_core.CostEstimate + + +@dataclasses.dataclass(frozen=True) +class TPUInterpretParams: + pass + + +class Semaphore: + def __init__(self): + self.cv = threading.Condition() + + # TODO(jburnim): Make this an array. 
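+    #
+    # Counts are tracked per device: a single Semaphore object services
+    # signals and waits from every interpreted device, keyed by that
+    # device's mesh coordinates.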
+ self.counts = collections.defaultdict(int) + + def signal(self, inc, device_id): + device_id = tuple(int(x) for x in device_id) + with self.cv: + self.counts[device_id] += inc + self.cv.notify_all() + + def wait(self, value, device_id): + device_id = tuple(int(x) for x in device_id) + with self.cv: + while self.counts[device_id] < value: + self.cv.wait() + self.counts[device_id] -= value + + +@dataclasses.dataclass(frozen=True) +class SharedMemory: + # (memory_space, buffer_id, device_id) -> NumPy array + # TODO(jburnim): Handle Megacore. + mem: dict = dataclasses.field(default_factory=dict) + + # semaphore_id -> Semaphore + sem: dict = dataclasses.field(default_factory=dict) + + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + next_buffer_id: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(lambda: 100)) + next_semaphore_id: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(lambda: 2000)) + +# TODO(jburnim): Do we want to support multiple instances of SharedMemory? +# Maybe for running multiple distinct interpreted computations in parallel? +_shared_memory = None +_shared_memory_init_lock = threading.Lock() + +def _get_shared_memory() -> SharedMemory: + global _shared_memory + if _shared_memory is None: + with _shared_memory_init_lock: + if _shared_memory is None: + _shared_memory = SharedMemory() + return _shared_memory + +def _clear_shared_memory(): + global _shared_memory + with _shared_memory_init_lock: + _shared_memory = None + +def _allocate_buffer(device_id, memory_space, val): + device_id = tuple(map(int, device_id)) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + val = np.array(val) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + buffer_id = shared_memory.next_buffer_id[device_id] + shared_memory.next_buffer_id[device_id] = buffer_id + 1 + shared_memory.mem[(memory_space, buffer_id, device_id)] = val + + # TODO(jburnim): Raise an error if buffer_id is too big for int16. + return np.int16(buffer_id) + +def _deallocate_buffer(device_id, memory_space, buffer_id): + device_id = tuple(map(int, device_id)) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + # TODO(jburnim): Error if buffer doesn't exist? + shared_memory.mem.pop((memory_space, buffer_id, device_id), None) + +def _allocate_semaphores(device_id, shape): + device_id = tuple(map(int, device_id)) + shape = tuple(map(int, shape)) + num_semaphores = math.prod(shape) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + semaphore_id = shared_memory.next_semaphore_id[device_id] + shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores + for i in range(semaphore_id, semaphore_id + num_semaphores): + if not i in shared_memory.sem: + shared_memory.sem[i] = Semaphore() + + # NOTE: For now, we use a relatively uncommon datatype (int16) for + # semaphore (and buffer) IDs, so these values are more easily identifiable + # in kernels. + # + # TODO(jburnim): Raise an error if any IDs are too big for int16. + return np.int16( + range(semaphore_id, semaphore_id + num_semaphores) + ).reshape(shape) + + +TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = { + v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)} +TPU_MEMORY_SPACE_NAMES = { + i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)} + +# Default to VMEM when no memory space is specified. 
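+# Buffers in SharedMemory.mem are keyed by memory-space name, so this fallback
+# only determines the key under which an un-annotated ref's backing array is
+# stored.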
+TPU_MEMORY_SPACE_IDXS[None] = ( + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM]) + +def get_barrier_semaphore(device_id, collective_id): + device_id = tuple(map(int, device_id)) + collective_id = int(collective_id) + + # TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict + # with IDs for regular or DMA semaphores. (For example, store them in a + # different table.) + shared_memory = _get_shared_memory() + with shared_memory.lock: + semaphore_id = collective_id + if not semaphore_id in shared_memory.sem: + shared_memory.sem[semaphore_id] = Semaphore() + + return np.int16(semaphore_id) + +def _transform_slice_or_index(slice_or_idx): + if isinstance(slice_or_idx, int): + return slice_or_idx + else: + start, size, stride = ( + slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride) + return slice(start, start + size * stride, stride) + +def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): + ret = [] + i = 0 + j = 0 + while True: + if i == len(slice_or_idx1): + ret.extend(slice_or_idx2[j:]) + return tuple(ret) + elif j == len(slice_or_idx2): + ret.extend(slice_or_idx1[i:]) + return tuple(ret) + elif isinstance(slice_or_idx1[i], int): + ret.append(slice_or_idx1[i]) + i += 1 + elif isinstance(slice_or_idx2[j], int): + ret.append(slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step) + i += 1 + j += 1 + else: + ret.append(slice( + slice_or_idx1[i].start + slice_or_idx2[j].start * slice_or_idx1[i].step, + slice_or_idx1[i].start + slice_or_idx2[j].stop * slice_or_idx1[i].step, + slice_or_idx1[i].step * slice_or_idx2[j].step + )) + i += 1 + j += 1 + +def _to_range(transforms) -> tuple[slice | int, ...]: + ret = () + for transform in transforms: + # For now, assume only NDIndexer transforms. + ret = _compose_slice_or_index( + ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) + return ret + +def get(device_id, memory_space, buffer_id, transforms): + device_id = tuple(int(x) for x in device_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + + shared_memory = _get_shared_memory() + with shared_memory.lock: + return shared_memory.mem[(memory_space, buffer_id, device_id)][ + _to_range(transforms) + ].copy() + +def store(device_id, memory_space, buffer_id, transforms, val): + device_id = tuple(int(x) for x in device_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + val = np.array(val) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + if transforms: + shared_memory.mem[(memory_space, buffer_id, device_id)][ + _to_range(transforms) + ] = val + else: + shared_memory.mem[(memory_space, buffer_id, device_id)] = val + +def swap(device_id, memory_space, buffer_id, transforms, val): + device_id = tuple(int(x) for x in device_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + val = np.array(val) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + result = shared_memory.mem[(memory_space, buffer_id, device_id)][ + _to_range(transforms) + ].copy() + shared_memory.mem[(memory_space, buffer_id, device_id)][ 
+ _to_range(transforms) + ] = val + + return np.array(result) + +def execute_dma(src, dst, send_sem, recv_sem): + # NOTE: `src` is a list of arguments for `get` (device_id, memory_space, + # buffer_id, transforms) and `dst` is a list of arguments for `store` + # (dst_device_id, dst_memory_space, dst_id, dst_transforms). + # + # TODO(jburnim): Clean this up. + + # Do the read. + data = get(*src) + data_size = data.itemsize * data.size + + # Signal the send semaphore. + if send_sem is not None: + send_sem.signal(data_size, device_id=src[0]) + + # Do the write. + store(*dst, data) + + # Signal the receive semaphore. + recv_sem.signal(data_size, device_id=dst[0]) + +def print_memory(device_id): + device_id = tuple(map(int, device_id)) + if all(d == 0 for d in device_id): + shared_memory = _get_shared_memory() + with shared_memory.lock: + print(shared_memory.mem) + +def dma_start(device_id, src_memory_space, src_id, src_transforms, + dst_memory_space, dst_id, dst_transforms, + dst_sem, + src_sem, + dst_device_id): + device_id = tuple(int(x) for x in device_id) + src_memory_space, src_id = int(src_memory_space), int(src_id) + src_transforms = jax.tree.map(int, src_transforms) + dst_memory_space, dst_id = int(dst_memory_space), int(dst_id) + dst_transforms = jax.tree.map(int, dst_transforms) + dst_sem = int(dst_sem) + if src_sem is not None: + src_sem = int(src_sem) + if dst_device_id is not None: + dst_device_id = tuple(int(x) for x in dst_device_id) + else: + dst_device_id = device_id + + shared_memory = _get_shared_memory() + with shared_memory.lock: + dst_sem = shared_memory.sem[dst_sem] + if src_sem is not None: + src_sem = shared_memory.sem[src_sem] + + # For now, just execute the DMA immediately. + # TODO(jburnim): Execute DMAs asynchronously. + execute_dma( + (device_id, src_memory_space, src_id, src_transforms), + (dst_device_id, dst_memory_space, dst_id, dst_transforms), + src_sem, + dst_sem) + +def dma_wait(device_id, sem, size): + device_id = tuple(int(x) for x in device_id) + sem = int(sem) + size = int(size) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + sem = shared_memory.sem[sem] + sem.wait(size, device_id) + +def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index): + device_id = tuple(map(int, device_id)) + sem = int(sem) + inc = int(inc) + target_device_id = tuple(map(int, target_device_id)) + + if target_core_index is not None: + raise NotImplementedError() + + shared_memory = _get_shared_memory() + with shared_memory.lock: + sem = shared_memory.sem[sem] + sem.signal(inc, target_device_id) + +def semaphore_wait(device_id, sem, value): + device_id = tuple(map(int, device_id)) + sem = int(sem) + value = int(value) + + shared_memory = _get_shared_memory() + with shared_memory.lock: + sem = shared_memory.sem[sem] + sem.wait(value, device_id) + +def _compute_transformed_shape_and_dtype(shape, dtype, transforms): + for transform in transforms: + if transform is None: + continue + shape = transform.transform_shape(shape) + dtype = transform.transform_dtype(dtype) + return shape, dtype + +@lu.cache +def _to_jaxpr(flat_fun, in_avals): + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts) + return new_jaxpr + +def _is_any(memory_space): + return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or + (memory_space == pallas_core.MemorySpace.ANY)) + +def _interpret_jaxpr(jaxpr, *args, compiler_params): + env = {} + + def read(var): + if isinstance(var, jax_core.Literal): + 
return var.val + else: + return env[var] + + def write(var, value): + env[var] = value + + jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) + + # Get the mesh coordinates. + device_coords = tuple( + lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes) + # TODO(jburnim): Convert to a single integer device ID. + # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing + # it for each sub-jaxpr. + + # TODO(jburnim): Clean up and finish this evaluation loop. For example: + # - Handle missing Pallas primitives, like masked_load. + # - Replace the big if-statement with a dictionary of rules. + # - Handle other higher-order primitives? + # - Megacore. + for eqn in jaxpr.eqns: + prim = eqn.primitive + invals = jax.util.safe_map(read, eqn.invars) + + if prim is primitives.load_p: + raise NotImplementedError() + + elif prim is primitives.swap_p: + raise NotImplementedError() + + elif prim is lax.cond_p: + def _make_branch(jaxpr): + return lambda *args: _interpret_jaxpr( + jaxpr, *args, compiler_params=compiler_params) + out = lax.switch( + invals[0], + [_make_branch(branch_jaxpr.jaxpr) + for branch_jaxpr in eqn.params['branches']], + *invals[1:]) + + elif prim is lax.scan_p: + consts, init_carry, xs = split_list( + invals, [eqn.params['num_consts'], eqn.params['num_carry']]) + def _scan_body(c, a): + return split_list( + _interpret_jaxpr(eqn.params['jaxpr'].jaxpr, *consts, *c, *a, + compiler_params=compiler_params), + [eqn.params['num_carry']]) + carry, out = lax.scan(_scan_body, init_carry, xs=xs, + length=eqn.params.get('length', None)) + out = carry + out + + elif prim is lax.while_p: + cond_consts, body_consts, init_vals = split_list( + invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']]) + out = lax.while_loop( + lambda args: _interpret_jaxpr(eqn.params['cond_jaxpr'].jaxpr, + *cond_consts, *args, + compiler_params=compiler_params)[0], + lambda args: _interpret_jaxpr(eqn.params['body_jaxpr'].jaxpr, + *body_consts, *args, + compiler_params=compiler_params), + init_vals) + + elif prim is pjit.pjit_p: + pjit_jaxpr = eqn.params['jaxpr'] + def f(*args): + return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args, + compiler_params=compiler_params) + in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) + new_jaxpr = _to_jaxpr(lu.wrap_init(f), in_avals) + out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) + + elif prim is primitives.run_scoped_p: + # Allocate a buffer or semaphore for each element of + # eqn.params['jaxpr'].invars . + allocs = [] + for v in eqn.params['jaxpr'].invars: + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + allocs.append(callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), + device_coords, + v.aval.shape, + ordered=True)) + else: + allocs.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_coords, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + primitives.uninitialized_value(v.aval.shape, v.aval.dtype), + ordered=True)) + + out = _interpret_jaxpr(eqn.params['jaxpr'], *invals, *allocs, + compiler_params=compiler_params) + + for a in allocs: + if isinstance(a, tuple): + callback.io_callback( + _deallocate_buffer, + None, + device_coords, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) + else: + # TODO(jburnim): Delete semaphores. 
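+          # For now, semaphores allocated by run_scoped are simply left in
+          # the global SharedMemory table when the scope exits.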
+ # callback.io_callback( + # _deallocate_semaphores, + # None, + # device_coords, + # a, + # ordered=True) + pass + + elif prim is state_primitives.get_p: + out = callback.io_callback( + get, + eqn.outvars[0].aval, + device_coords, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[1:]), + ordered=True) + + elif prim is state_primitives.swap_p: + out = callback.io_callback( + swap, + eqn.outvars[0].aval, + device_coords, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[2:]), + invals[1], + ordered=True) + + elif prim is mosaic_primitives.dma_start_p: + (src, src_transforms, + dst, dst_transforms, + dst_sem, dst_sem_transforms, + src_sem, src_sem_transforms, + device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + (orig_src_ref, _, orig_dst_ref, *_ + ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) + callback.io_callback( + dma_start, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space], + src, src_transforms, + TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space], + dst, dst_transforms, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + state_discharge.transform_array(src_sem, src_sem_transforms), + device_id, + ordered=True) + out = [] + + elif prim is mosaic_primitives.dma_wait_p: + (src, src_transforms, + dst, dst_transforms, + dst_sem, dst_sem_transforms, + src_sem, src_sem_transforms, + device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + read_shape, read_dtype = _compute_transformed_shape_and_dtype( + eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) + callback.io_callback( + dma_wait, + (), + device_coords, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + math.prod(read_shape) * read_dtype.itemsize, + ordered=True) + out = [] + + elif prim is mosaic_primitives.get_barrier_semaphore_p: + out = callback.io_callback( + get_barrier_semaphore, + jax.ShapeDtypeStruct((), jnp.int16), + device_coords, + compiler_params['mosaic']['collective_id'], + ordered=True) + + elif prim is mosaic_primitives.semaphore_signal_p: + sem, sem_transforms, inc, device_id, core_index = ( + jax.tree.unflatten(eqn.params['args_tree'], invals)) + callback.io_callback( + semaphore_signal, + (), + device_coords, + state_discharge.transform_array(sem, sem_transforms), + inc, + device_id, + core_index, + ordered=True) + out = [] + + elif prim is mosaic_primitives.semaphore_wait_p: + sem, sem_transforms, value = ( + jax.tree.unflatten(eqn.params['args_tree'], invals)) + callback.io_callback( + semaphore_wait, + (), + device_coords, + state_discharge.transform_array(sem, sem_transforms), + value, + ordered=True) + out = [] + + else: + out = prim.bind(*invals, **eqn.params) + + out = out if prim.multiple_results else [out] + jax.util.safe_map(write, eqn.outvars, out) + + return jax.util.safe_map(read, jaxpr.outvars) + +def _initialize_output_vals( + block_mappings_output: Iterable[BlockMapping], + input_args, input_output_aliases) -> Sequence[jax.Array]: + oi_map = {v: k for k, v in input_output_aliases} + output_vals = [] + for i, bm in enumerate(block_mappings_output): + if i in oi_map: + output_vals.append(input_args[oi_map[i]]) + else: + output_vals.append(primitives.uninitialized_value( + bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) + return output_vals + +def _compute_start_indices(block_mapping, loop_idx, *args): + block_indices = ( + 
jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) + if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): + ret = tuple(i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices)) + elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + ret = block_indices + else: + raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") + return ret + +def _get_next_indices(grid, indices): + next_indices = [] + carry = True + for dim_size, index in reversed(list(zip(grid, indices))): + i = jnp.where(carry, index + 1, index) + carry = dim_size == i + next_indices.append(jnp.where(carry, 0, i)) + return tuple(reversed(next_indices)) + +def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): + start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx) + output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) + squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, + dtype=np.bool_)]) + return lax.squeeze(output, squeeze_dims) + +def get_interpret_effects(): + return {callback._OrderedIOEffect} + +def interpret_pallas_call( + *args, + jaxpr: jax_core.Jaxpr, + name_and_src_info: pallas_core.NameAndSrcInfo, + debug: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: GridMapping, + compiler_params: Any, + cost_estimate: CostEstimate, + out_avals: tuple[jax_core.AbstractValue, ...], +): + del debug, cost_estimate, out_avals + + # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) + dynamic_grid_args, scalars, input_args = split_list( # type: ignore + args, + [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands], + ) + # TODO(jburnim): Support dynamic grid sizes? + grid = grid_mapping.static_grid + + device_coords = tuple( + lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes) + + # Allocate buffers in HBM for outputs. + io_alias_map = dict(input_output_aliases) + output_buffer_ids = [] + output_vals = _initialize_output_vals( + grid_mapping.block_mappings_output, args, input_output_aliases) + for out_val in output_vals: + output_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + out_val, + ordered=True)) + + # Allocate buffers for all kernel arguments (e.g., scalars, inputs, + # outputs, scratch). + kernel_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + kernel_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], + val, + ordered=True)) + for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): + output_idx = i - grid_mapping.num_inputs + is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) + if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + kernel_buffer_ids.append(callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), + device_coords, + var.aval.shape, + ordered=True)) + elif is_output and _is_any(var.aval.memory_space): + # Don't allocate a buffer -- use the already-allocated HBM output buffer. + # + # TODO(jburnim): For kernel args in HBM, check that block shape is the + # same as for the corresponding pallas_call input, and that the index_map + # is trivial. 
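+      # Because the kernel reads and writes this ref directly in HBM, the
+      # per-iteration copy-out loop below skips it.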
+ kernel_buffer_ids.append(output_buffer_ids[output_idx]) + elif (i < grid_mapping.num_inputs) and (i in io_alias_map): + # Instead of allocating a new buffer, use the already-allocated + # HBM output buffer. + assert _is_any(var.aval.memory_space) + kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]]) + else: + kernel_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_coords, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + primitives.uninitialized_value(var.aval.shape, var.aval.dtype), + ordered=True)) + + num_inputs = grid_mapping.num_inputs + _, input_ids, kernel_output_ids, _ = split_list( + kernel_buffer_ids, + [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) + input_vars, output_vars = split_list( + jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) + + # For kernel inputs that are in HBM, we populate the buffer once before + # any kernel invocations. + for buffer_id, var, val in zip(input_ids, input_vars, input_args): + if not _is_any(var.aval.memory_space): + continue + if val.shape != var.aval.shape: + # TODO(jburnim): Also check that the index_map is trivial. + raise ValueError() + callback.io_callback( + store, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + buffer_id, + (), + val, + ordered=True) + + is_indexing_dim = [ + tuple(b is pallas_core.mapped for b in bm.block_shape) + for bm in grid_mapping.block_mappings + ] + block_shapes = [ + tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + ] + scalar_ids, in_out_ids, scratch_ids = split_list( + kernel_buffer_ids, + [grid_mapping.num_index_operands, len(grid_mapping.block_mappings)]) + + if grid: + num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type] + else: + # Base case is always one iteration when grid is () + num_iterations = 1 + + def body(carry): + # The loop carry: (i, loop_idx) -- + # - i:int32 is the interation index + # - loop_idx: tuple[int32] are the program ids for each grid axis + i, loop_idx = carry + + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + + with pallas_core.grid_env(local_grid_env): + # Copy slices of the input to the kernel buffers. + # + # TODO(jburnim): Only copy slices when the index mapping has changed? + start_indices = [_compute_start_indices(bm, loop_idx, *scalars) + for bm in grid_mapping.block_mappings] + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j], + input_args[j], is_indexing_dim[j]) + assert(sliced_val.shape == var.aval.shape) + callback.io_callback( + store, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + input_ids[j], + (), + sliced_val, + ordered=True) + + # Invoke the kernel. + _interpret_jaxpr(jaxpr, *kernel_buffer_ids, + compiler_params=compiler_params) + + # Copy from the kernel buffers to slices of the output in HBM. + # + # TODO(jburnim): Only copy if the index mapping will change in the + # next iteration (or if this is the last iteration)? 
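+      # Each output block is read in full from its kernel buffer and stored
+      # into the corresponding HBM output buffer at the start indices
+      # computed from that output's index_map.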
+ for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + kernel_output_val = callback.io_callback( + get, + var.aval, + device_coords, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + kernel_output_ids[j], + (), + ordered=True) + transform = indexing.NDIndexer( + indices=tuple(indexing.ds(st, sz) + for st, sz in zip(start_indices[num_inputs + j], + block_shapes[num_inputs + j])), + shape=output_vals[j].shape, + int_indexer_shape=()) + callback.io_callback( + store, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + output_buffer_ids[j], + (transform,), + kernel_output_val, + ordered=True) + + return i + 1, _get_next_indices(grid, loop_idx) + + # TODO(jburnim): Handle parallel grid dimensions + megacore. + _ = lax.while_loop( + lambda carry: carry[0] < num_iterations, + body, + (jnp.int32(0), (jnp.int32(0),) * len(grid)) + ) + + # Read the output from the allocated output buffers. + ret = [ + callback.io_callback( + get, + val, + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + output_buffer_id, + (), + ordered=True) + for val, output_buffer_id in zip(output_vals, output_buffer_ids) + ] + + for buffer_id in output_buffer_ids: + callback.io_callback( + _deallocate_buffer, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + buffer_id, + ordered=True) + for buffer_id, var in zip(kernel_buffer_ids, jaxpr.invars): + if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + pass + else: + callback.io_callback( + _deallocate_buffer, + (), + device_coords, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + buffer_id, + ordered=True) + + return ret diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b53e7412f..ccef4152f 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -17,7 +17,9 @@ from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses +import enum from functools import partial, reduce +import types from typing import Any, Literal import jax @@ -82,19 +84,31 @@ pallas_call_p.def_impl(_pallas_call_impl) def _pallas_call_abstract_eval( - *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_ + *avals, + out_avals: tuple[jax_core.AbstractValue, ...], + interpret, + backend, + **params ): del avals + + if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + # Report effects that will be introduced when running/lowering + # mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call . + effs = mosaic_tpu_interpret.get_interpret_effects() + else: + effs = jax_core.no_effects + # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. 
return [ jax_core.ShapedArray(a.shape, a.dtype, a.weak_type) if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) else a for a in out_avals - ] + ], effs -pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) +pallas_call_p.def_effectful_abstract_eval(_pallas_call_abstract_eval) def _pallas_call_jvp_rule( @@ -1230,9 +1244,12 @@ def _pallas_call_lowering( if params['jaxpr'].constvars: raise ValueError('Cannot lower a pallas_call with constants.') if interpret: - impl = partial(hlo_interpreter.pallas_call_hlo_interpret, - backend=backend, - **params) + if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + impl = partial(mosaic_tpu_interpret.interpret_pallas_call, **params) + else: + impl = partial(hlo_interpreter.pallas_call_hlo_interpret, + backend=backend, + **params) return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes) def cpu_lowering(ctx: mlir.LoweringRuleContext, @@ -1681,3 +1698,10 @@ try: from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend except ImportError: mosaic_tpu_backend = None # type: ignore + +try: + from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret +except ImportError: + mosaic_tpu_interpret = types.SimpleNamespace( # type: ignore + TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)), + ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index f21f0a8e4..11beaeae0 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -426,6 +426,32 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "tpu_pallas_interpret_test", + srcs = [ + "tpu_pallas_interpret_test.py", + ], + disable_configs = ["cpu_shardy"], + enable_backends = ["cpu"], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "tpu_pallas_interpret_distributed_test", + srcs = [ + "tpu_pallas_interpret_distributed_test.py", + ], + disable_configs = ["cpu_shardy"], + enable_backends = ["cpu"], + deps = [ + "//third_party/py/jax:pallas", + "//third_party/py/jax:pallas_tpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py new file mode 100644 index 000000000..059650c3f --- /dev/null +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -0,0 +1,995 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TPU-specific interpret mode. + +To work around https://github.com/jax-ml/jax/issues/25671 , this file +contains only tests that use shard_map. 
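+
+These tests run distributed Pallas TPU kernels under TPUInterpretParams on
+CPU devices and compare the results against the corresponding XLA
+collectives.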
+""" + +from absl.testing import absltest +import numpy as np + +import jax +from jax import lax +from jax._src import test_util as jtu +import jax._src.pallas.mosaic.interpret as mosaic_interpret +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + +P = jax.sharding.PartitionSpec + + +class InterpretDistributedTest(jtu.JaxTestCase): + + def test_right_permute_example(self): + num_devices = jax.device_count() + if num_devices < 4: + self.skipTest(f'requires at least 4 devices, found {num_devices}') + partition = P(None, 'x') + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + # Create an input array that shards the last dimension across + # all devices. + input_arr = jax.random.uniform( + jax.random.key(0), (8, 128 * num_devices), dtype=jnp.float32) + input_arr = jax.device_put(input_arr, sharding) + + def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): + my_id = lax.axis_index('x') + left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices)) + right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices)) + + barrier_sem = pltpu.get_barrier_semaphore() + def _body(ijk): + i, (j, k) = ijk + lax.cond( + (i == 0) | (j == 0), + lambda: pltpu.semaphore_signal( + barrier_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH), + lambda: pltpu.semaphore_signal( + barrier_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH)) + return (i + 1, (j + 1, k + 1)) + lax.while_loop(lambda ijk: ijk[0] < 2, _body, (0, (0, 0))) + pltpu.semaphore_wait(barrier_sem, 2) + + def _body2(i, a): + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + remote_copy_op.wait() + + return i + 1, a + 1 + _ = lax.scan(_body2, 0, jnp.arange(4.0), unroll=2) + + out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=( + # We allocate DMA semaphores in scratch memory. + [pltpu.SemaphoreType.DMA] * 2 + ), + ) + right_permute = pl.pallas_call( + right_permute_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=pltpu.TPUCompilerParams(collective_id=13), + interpret=mosaic_interpret.TPUInterpretParams(), + ) + # Wrap the kernel within a shard_map to call. + pallas_result = jax.jit( + shard_map.shard_map( + right_permute, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) + )(input_arr) + + # Compare Pallas result to XLA shard_map result. 
+ perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) + xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.ppermute(x, 'x', perm), + mesh=mesh, in_specs=partition, out_specs=partition) + )(input_arr) + + np.testing.assert_allclose(xla_result, pallas_result) + + + def test_all_gather_example(self): + num_devices = jax.device_count() + if num_devices < 4: + self.skipTest(f'requires at least 4 devices, found {num_devices}') + partition = P('x', None) + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + # Create an input array that shards the first dimension across + # all devices. + input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128)) + input_arr = jax.device_put(input_arr, sharding) + + def all_gather_kernel(input_ref, + output_ref, + local_copy_sem, + send_sem, + recv_sems): + outer_step = pl.program_id(0) + my_id = lax.axis_index('x') + left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices)) + right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices)) + copy_slot = my_id - outer_step + copy_slot = lax.rem(copy_slot + num_devices, jnp.int32(num_devices)) + + @pl.when(outer_step == 0) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + local_copy_op = pltpu.make_async_copy( + src_ref=input_ref, + dst_ref=output_ref.at[my_id], + sem=local_copy_sem, + ) + local_copy_op.start() + local_copy_op.wait() + + # Copy to our right neighbor. + # Note that we will also be receiving data from our left neighbor, + # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact + # that the indices do not need to be symmetric between remote DMAs. + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=output_ref.at[copy_slot], + dst_ref=output_ref.at[copy_slot], + send_sem=send_sem, + recv_sem=recv_sems.at[outer_step], + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + remote_copy_op.wait() + + out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=( + # DMA semaphores are allocated in scratch memory. + # We allocated one semaphore for a local HBM-VMEM copy, + # and one for the remote send semaphore. + [pltpu.SemaphoreType.DMA] * 2 + # We additionally allocate one receive semaphore per device. + # This is to avoid situations where we have multiple + # DMAs in flight, as we do not want to share a receive + # semaphore between the DMAs. + + [pltpu.SemaphoreType.DMA((num_devices-1,))] + ), + grid=(num_devices-1,) + ) + + all_gather = pl.pallas_call( + all_gather_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + ) + + # Wrap the kernel within a shard_map to call. 
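+    # Under TPUInterpretParams the kernel jaxpr is evaluated on CPU, with
+    # remote DMAs and semaphores emulated via the interpreter's process-wide
+    # shared memory.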
+ pallas_result = jax.jit( + shard_map.shard_map( + all_gather, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False + ) + )(input_arr) + + # Compare Pallas result to XLA shard_map result. + xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.all_gather(x, 'x'), + mesh=mesh, in_specs=partition, out_specs=partition + ) + )(input_arr) + + np.testing.assert_allclose(xla_result, pallas_result) + + def test_all_reduce_sum_example(self): + num_devices = jax.device_count() + if num_devices < 4: + self.skipTest(f'requires at least 4 devices, found {num_devices}') + partition = P(None, 'x') + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + input_arr = jax.random.uniform( + jax.random.key(0), shape=(8, 128 * num_devices)) + input_arr = jax.device_put(input_arr, sharding) + + def all_reduce_kernel( + x_ref, + o_ref, + hbm_scratch, + copy_sem, + remote_recv_sem, + remote_send_sem, + capacity_sem, + receive_scratch, + ): + outer_step = pl.program_id(0) + working_slot = lax.rem(outer_step, jnp.int32(2)) + receiving_slot = 1 - working_slot + + my_id = lax.axis_index('x') + right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices)) + left_neighbor = lax.rem(my_id - 1 + num_devices, jnp.int32(num_devices)) + + @pl.when(outer_step == 0) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + # Initialize o_ref, acc_scratch, and hbm_scratch. + o_ref[...] = jnp.zeros_like(o_ref) + receive_scratch[...] = jnp.zeros_like(receive_scratch) + initial_copy = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=hbm_scratch.at[working_slot], + send_sem=remote_send_sem, + recv_sem=remote_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + initial_copy.start() + initial_copy.wait() + + # Signal to our left neighbor that we are ready to receive. + # Without this signal, our left neighbor can be >=1 iteration ahead, + # meaning it could write into our working slot. + pltpu.semaphore_signal( + capacity_sem, + inc=1, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # Copy the partial result our left neighbor sent to us into VMEM for + # computation. + local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=receive_scratch, + sem=copy_sem, + ) + local_copy.start() + + # Block until our right neighbor is ready to receive. + pltpu.semaphore_wait(capacity_sem, 1) + # Pass the value to our right neighbor. + remote_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=hbm_scratch.at[receiving_slot], + send_sem=remote_send_sem, + recv_sem=remote_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy.start() + # Finish local copy and accumulate while remote_copy is happening. + local_copy.wait() + o_ref[...] += receive_scratch[...] + # Block until remote copy finishes. + remote_copy.wait() + + out_shape = ( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + # We allocate the double-buffer as a Pallas output so that it is + # resident in HBM. 
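+        # Two slots: one is read and accumulated locally while the other
+        # receives the next partial sum from the left neighbor.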
+ jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch + ) + + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # Our input lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + # Our output lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + # Our double-buffer lives in HBM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices,), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 3 + + [pltpu.SemaphoreType.REGULAR] # capacity_sem + + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch + ), + ) + + kernel = pl.pallas_call( + all_reduce_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + ) + + pallas_result = jax.jit( + shard_map.shard_map( + kernel, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) + )(input_arr) + pallas_result = jax.block_until_ready(pallas_result)[0] + + def lax_sum(x): + return lax.psum(x, 'x') + + xla_result = jax.jit( + shard_map.shard_map( + lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x') + ) + )(input_arr) + + np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + + + def test_reduce_scatter_sum_example(self): + num_devices = jax.device_count() + if num_devices < 4: + self.skipTest(f'requires at least 4 devices, found {num_devices}') + partition = P(None, 'x') + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + # We need a block size of (16, 128) to ensure that a half-slice is at least + # of size (8, 128), which is the size of a VREG. This makes tiling easier + # for the compiler. 
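+    # (The tiling constraint matters for compiled TPU kernels; interpret mode
+    # does not require it, but we keep the same block size so this test
+    # mirrors the compiled version of the kernel.)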
+ block_size = (16, 128) + input_arr = jax.random.uniform( + jax.random.key(0), + shape=(block_size[0] * num_devices, block_size[1] * num_devices), + dtype=jnp.float32, + ) + input_arr = jax.device_put(input_arr, sharding) + + LEFT = 0 + RIGHT = 1 + + def mod(x, n): + return lax.rem(x + n, n) + + def signal(left_or_right, semaphore): + my_id = lax.axis_index('x') + if left_or_right == LEFT: + neighbor = mod(my_id - 1, jnp.int32(num_devices)) + else: + neighbor = mod(my_id + 1, jnp.int32(num_devices)) + pltpu.semaphore_signal( + semaphore, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + local_copy_sem, + left_recv_sem, + left_send_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, + accum_scratch, + ): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, jnp.int32(2)) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, jnp.int32(num_devices)) + left_neighbor = mod(my_id - 1, jnp.int32(num_devices)) + + left_copy_device = mod(my_id + outer_step + 1, jnp.int32(num_devices)) + right_copy_device = mod(my_id - outer_step - 1, jnp.int32(num_devices)) + # Slices can be specified using pl.ds(start, size) + left_copy_slice = pl.ds(0, block_size[0] // 2) + right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2) + current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2) + + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + # Note: Right copy is flipped with regards to slots since we are copying + # to the next outer_step iteration. + src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. + o_ref[...] 
= jnp.zeros_like(o_ref[...]) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + # --- Body --- + # At the beginning of our kernel body, we start a DMA which copies + # the result we computed in the previous phase to our neighbor. + # This allows us to overlap the communication of sending our previous phase + # with the computation for the current phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=accum_scratch, + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_scratch[...] += x_ref[left_copy_device, left_copy_slice] + + @pl.when(phase == RIGHT) + def _(): + accum_scratch[...] += x_ref[right_copy_device, right_copy_slice] + + local_copy = pltpu.make_async_copy( + src_ref=accum_scratch, + dst_ref=hbm_scratch.at[working_slot, current_phase_slice], + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + # At the end of our kernel body, we wait on the DMA of the previous phase + # to make sure the results are ready for the next phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # --- Epilogue --- + # Store result on last iteration. + @pl.when(last_iteration) + def _(): + # Clean up semaphores so that they exit with a value of 0. + @pl.when(phase == LEFT) + def _(): + o_ref[left_copy_slice, ...] = accum_scratch[...] + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + o_ref[right_copy_slice, ...] = accum_scratch[...] 
+ pltpu.semaphore_wait(left_capacity_sem, 1) + + out_shape = ( + jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, block_size[0], block_size[1]), jnp.float32 + ), # hbm_scratch + ) + + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + + [ + pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32) + ] # accum_scratch + ), + ) + + def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1]) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams(collective_id=7), + )(input_arr)[0] + + pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) + )(input_arr) + pallas_result = jax.block_until_ready(pallas_result) + + # Compare our result to XLA. + def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, block_size[0], block_size[1]) + return lax.psum_scatter(x, 'x') + + xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) + )(input_arr) + + np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + + def test_reduce_scatter_sum_with_emit_pipeline_example(self): + self.skipTest('requires a patched pallas.emit_pipeline to specify/fake ' + 'the TPU generation') + if jax.config.jax_enable_x64: + self.skipTest('pallas.emit_pipeline + x64 is not currently supported') + num_devices = jax.device_count() + if num_devices < 4: + self.skipTest(f'requires at least 4 devices, found {num_devices}') + partition = P(None, 'x') + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + # We pick a large outer kernel block size that we do not want to place + # in VMEM. For pedagogical purposes we use (4096, 4096), although in + # principle this can be much larger. + outer_block_size = (512, 512) + # We pick a smaller VMEM block size for the inner kernel. 
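+    # The inner pipeline streams blocks of this size between HBM and VMEM
+    # while accumulating.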
+ inner_block_size = (128, 128) + input_arr = jax.random.uniform( + jax.random.key(0), + shape=( + outer_block_size[0] * num_devices, + outer_block_size[1] * num_devices, + ), + ) + input_arr = jax.device_put(input_arr, sharding) + + inner_grid = ( + outer_block_size[0] // inner_block_size[0] // 2, + outer_block_size[1] // inner_block_size[1], + ) + inner_block_spec = pl.BlockSpec( + index_map=lambda i, j: (i, j), + block_shape=inner_block_size, + memory_space=pltpu.TPUMemorySpace.ANY, + ) + + LEFT = 0 + RIGHT = 1 + + def mod(x, n): + return lax.rem(x + n, n) + + def signal(left_or_right, semaphore): + my_id = lax.axis_index('x') + if left_or_right == LEFT: + neighbor = mod(my_id - 1, num_devices) + else: + neighbor = mod(my_id + 1, num_devices) + pltpu.semaphore_signal( + semaphore, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + left_recv_sem, + left_send_sem, + copy_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, + ): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + left_copy_device = mod(my_id + outer_step + 1, num_devices) + right_copy_device = mod(my_id - outer_step - 1, num_devices) + left_copy_slice = pl.ds(0, outer_block_size[0] // 2) + right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2) + current_phase_slice = pl.ds( + phase * (outer_block_size[0] // 2), outer_block_size[0] // 2 + ) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. 
+ barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + # --- Body --- + def inner_kernel(input_ref, accum_ref): + # We do not explicitly use += because we set should_accumulate_out=True. + accum_ref[...] = input_ref[...] + + accum_pipeline = pltpu.emit_pipeline( + inner_kernel, + in_specs=[inner_block_spec], + out_specs=inner_block_spec, + should_accumulate_out=True, + grid=inner_grid, + ) + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_pipeline( + x_ref.at[left_copy_device, left_copy_slice], + hbm_scratch.at[working_slot, left_copy_slice], + ) + + @pl.when(phase == RIGHT) + def _(): + accum_pipeline( + x_ref.at[right_copy_device, right_copy_slice], + hbm_scratch.at[working_slot, right_copy_slice], + ) + + # --- Epilogue --- + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # Store result on last iteration. + @pl.when(last_iteration) + def _(): + output_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=o_ref.at[current_phase_slice], + sem=copy_sem, + ) + output_copy.start() + output_copy.wait() + + # Clean up semaphores so that they exit with a value of 0. 
+ @pl.when(phase == LEFT) + def _(): + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + pltpu.semaphore_wait(left_capacity_sem, 1) + + + out_shape = ( + jax.ShapeDtypeStruct( + (outer_block_size[0], outer_block_size[1]), jnp.float32 + ), + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, outer_block_size[0], outer_block_size[1]), jnp.float32 + ), # hbm_scratch + ) + + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + ), + ) + + def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape( + num_devices, outer_block_size[0], outer_block_size[1] + ) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams(collective_id=19), + )(input_arr)[0] + + pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) + )(input_arr) + pallas_result = jax.block_until_ready(pallas_result) + + def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1]) + return lax.psum_scatter(x, 'x') + + xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) + )(input_arr) + + np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py new file mode 100644 index 000000000..4524cb649 --- /dev/null +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -0,0 +1,67 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TPU-specific interpret mode. + +To work around https://github.com/jax-ml/jax/issues/25671 , this file +contains only tests that do not use shard_map. +""" + +from absl.testing import absltest +import numpy as np + +import jax +from jax._src import test_util as jtu +import jax._src.pallas.mosaic.interpret as mosaic_interpret +from jax.experimental import pallas as pl + +jax.config.parse_flags_with_absl() + + +class InterpretTest(jtu.JaxTestCase): + + def test_matmul_example(self): + num_devices = jax.device_count() + if num_devices > 1: + # Workaround for https://github.com/jax-ml/jax/issues/25671 + self.skipTest(f'requires 1 device, found {num_devices}') + + def matmul_kernel(x_ref, y_ref, z_ref): + z_ref[...] = x_ref[...] @ y_ref[...] 
+ + @jax.jit + def matmul(x: jax.Array, y: jax.Array): + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), + ), + interpret=mosaic_interpret.TPUInterpretParams(), + )(x, y) + + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (1024, 1024)) + y = jax.random.normal(k2, (1024, 1024)) + z = matmul(x, y) + np.testing.assert_allclose(z, x @ y, atol=1e-4) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())