# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from collections.abc import Iterable, Sequence
import dataclasses
import enum
import functools
import itertools
import math
import threading
from typing import Any, Literal

import jax
from jax import lax
from jax._src import callback
from jax._src import core as jax_core
from jax._src.lax.control_flow import for_loop
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.pallas.mosaic import primitives as mosaic_primitives
from jax._src.pallas.mosaic import core as mosaic_core
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src import pjit
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.util import (
    safe_map,
    safe_zip,
    split_list
)
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
import numpy as np


map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
ScratchShapeTree = pallas_core.ScratchShapeTree
CostEstimate = pallas_core.CostEstimate


@dataclasses.dataclass(frozen=True)
class TPUInterpretParams:
  """Parameters for Mosaic TPU interpret mode.

  Attributes:
    dma_execution_mode: If "eager", DMAs are executed as soon as they are
      issued. If "on_wait", DMA reads or writes are only executed when a device
      is waiting on a DMA semaphore that will be signaled when the read or
      write is complete.
      Default: "on_wait".
    detect_races: If True, a dynamic, happens-before race detector will be
      used to detect data races during kernel interpretation. If any races are
      detected, a message will be printed and `races.races_found` will be set
      to True.
      Default: False.
    skip_floating_point_ops: If True, operations that produce only floating
      point values will not be interpreted; instead, their results will be
      replaced with arrays all of `jnp.inf`. Additionally, any floating point
      operands to any operation will be replaced with (arrays of) `jnp.inf`.
      Default: False.
    uninitialized_memory: If "nan", allocated buffers are initialized to
      contain all NaNs (or to their maximum possible value for integers).
      If "zero", allocated buffers are initialized to all zeros.
      Default: "nan".
  """
  dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
  detect_races: bool = False
  skip_floating_point_ops: bool = False
  uninitialized_memory: Literal["nan", "zero"] = "nan"
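

# Example usage (an illustrative sketch, not part of this module): these
# parameters are intended to be forwarded from `pl.pallas_call` into this
# interpreter. The `interpret=` argument shown below is an assumption made
# for illustration; consult the Pallas TPU interpret-mode docs for the exact
# entry point.
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   def kernel(x_ref, o_ref):
#     o_ref[...] = x_ref[...] + 1.0
#
#   out = pl.pallas_call(
#       kernel,
#       out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
#       interpret=TPUInterpretParams(
#           dma_execution_mode="on_wait", detect_races=True),
#   )(jnp.zeros((8, 128), jnp.float32))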


VectorClock = np.ndarray

# Conceptually, each DMA runs on its own, independent device. Representing
# this precisely would require vector clocks to have sizes linear in the number
# of DMAs.
#
# Instead, we use approximate vector clocks of fixed size. We assign each DMA
# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] --
# and each operation of a DMA increments the corresponding coordinate in its
# vector clock. (So the "virtual" part of a vector clock is effectively
# counting, for each virtual device, the number of DMAs that happened-before
# the vector clock and were assigned to that virtual device.)
#
# If two approximate clocks are unordered, then their corresponding events are
# not ordered by the happens-before relation. So this approximation will not
# introduce any false positives in detecting data races. But we may fail to
# detect some true data races because there can be cases where two approximate
# clocks are ordered, and we will treat the corresponding events as ordered
# by the happens-before relation, but the corresponding events are not
# actually ordered.
NUM_VIRTUAL_DEVICES = 32

def make_vector_clock(num_devices: int) -> VectorClock:
  del num_devices
  return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32)

def copy_vector_clock(x: VectorClock) -> VectorClock:
  if x is None:
    return None
  return x.copy()

def update_vector_clock(x: VectorClock, y: VectorClock):
  x[:] = np.maximum(x, y)

def lt(x: VectorClock, y: VectorClock) -> bool:
  return bool((x <= y).all() & (x < y).any())

def ordered(x: VectorClock, y: VectorClock) -> bool:
  return lt(x, y) | lt(y, x)

def inc_vector_clock(x: VectorClock, device_id: int):
  if device_id >= len(x):
    raise ValueError(f'device_id={device_id} is out of range for x={x}')
  assert device_id < len(x)
  x[device_id] += 1
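
# Illustrative example (not part of the implementation): two clocks that each
# lead in a different coordinate are unordered, so the corresponding events
# may race; merging one clock into the other establishes happens-before.
#
#   a = make_vector_clock(4); inc_vector_clock(a, 0)   # a = [1, 0, 0, ...]
#   b = make_vector_clock(4); inc_vector_clock(b, 1)   # b = [0, 1, 0, ...]
#   ordered(a, b)              # False -- neither happened-before the other.
#   update_vector_clock(b, a)  # b = [1, 1, 0, ...]
#   lt(a, b)                   # True -- a happens-before b.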


class Semaphore:
  def __init__(self, semaphore_id=None):
    shared_memory = _get_shared_memory()

    self.id = semaphore_id

    # TODO(jburnim): Use one Condition variable per device. (Which will be
    # easier to do when we're using single integer device IDs.)
    self.cv = threading.Condition()

    self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32)

    self.interpret_params = shared_memory.interpret_params
    if self.interpret_params.detect_races:
      # We associate a vector clock with each count in self.counts. Whenever
      # self.counts[i] is signaled, self.clocks[i] is updated with the vector
      # clock of the signaling device. Whenever device i successfully waits on
      # self.counts[i], the vector clock of device i is updated with
      # self.clocks[i].
      #
      # TODO(jburnim): Model happens-before more precisely for the case where
      # semaphores are over-signaled.
      self.clocks = [None] * shared_memory.num_devices

  def signal(self, inc, device_id, clock):
    """Signal the semaphore on `device_id` by `inc`.

    Args:
      inc: A positive integer. The amount by which to increment the semaphore
        on the target device.
      device_id: The ID of the target device.
      clock: The vector clock of the signaling device at the time of the signal.
    """
    device_id = int(device_id)
    with self.cv:
      self.counts[device_id] += inc
      if self.interpret_params.detect_races:
        if self.clocks[device_id] is None:
          self.clocks[device_id] = copy_vector_clock(clock)
        else:
          update_vector_clock(self.clocks[device_id], clock)
      self.cv.notify_all()

  def read(self, device_id):
    with self.cv:
      return self.counts[device_id]

  def wait(self, value, device_id, *, is_dma=False):
    device_id = int(device_id)
    shared_memory = _get_shared_memory()

    # TODO(jburnim):
    #  - If the count is larger than value, raise an error?
    #  - If the count is equal to value, but there are DMAs waiting to signal
    #    us, raise an error?

    # Simple implementation for non-DMA semaphores.
    if not is_dma or (self.interpret_params.dma_execution_mode == "eager"):
      with self.cv:
        while self.counts[device_id] < value:
          self.cv.wait()
        self.counts[device_id] -= value
        if self.interpret_params.detect_races:
          clock = copy_vector_clock(self.clocks[device_id])
      if self.interpret_params.detect_races:
        with shared_memory.lock:
          update_vector_clock(shared_memory.clocks[device_id], clock)
      return

    # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count
    # is not large enough we will select and partially execute pending DMAs
    # until our count is large enough.
    #
    # This approach will tend to run DMAs as late as possible, as well as
    # out-of-order. This approach also lets us avoid the complexity of spinning
    # up separate threads to handle executing DMAs.
    shared_memory = _get_shared_memory()
    while True:
      clock = None
      with self.cv:
        if self.counts[device_id] >= value:
          self.counts[device_id] -= value
          if self.interpret_params.detect_races:
            clock = copy_vector_clock(self.clocks[device_id])
          else:
            return
      if clock is not None:
        with shared_memory.lock:
          update_vector_clock(shared_memory.clocks[device_id], clock)
        return

      with shared_memory.lock:
        dma_queue = shared_memory.dmas_by_sem[self.id]
        if len(dma_queue) > 0:
          dma = dma_queue.pop()
        else:
          continue

      # Only execute the DMA as far as necessary to signal us.
      assert (dma.src_sem is self) or (dma.dst_sem is self)
      with dma.lock:
        if dma.virtual_device_id is None:
          dma.virtual_device_id = np.random.randint(
              shared_memory.num_devices, NUM_VIRTUAL_DEVICES)

        if dma.state == DmaState.STARTED:
          # Do the read.
          if self.interpret_params.detect_races:
            inc_vector_clock(dma.clock, dma.virtual_device_id)
          dma.data = get(dma.src_device_id,
                         dma.src_memory_space,
                         dma.src_buffer_id,
                         dma.src_transforms,
                         clock=copy_vector_clock(dma.clock),
                         src_device_id=dma.id,
                         source_info=dma.source_info)
          if self.interpret_params.detect_races:
            inc_vector_clock(dma.clock, dma.virtual_device_id)
          if dma.src_sem is not None:
            data_size = dma.data.itemsize * dma.data.size
            dma.src_sem.signal(
                data_size, device_id=dma.src_device_id, clock=dma.clock)
          dma.state = DmaState.READ

        if dma.src_sem is self:
          # We were only waiting for the DMA read (i.e., we're the send
          # semaphore), so leave the DMA write for later.
          continue
        assert dma.state == DmaState.READ

        # Do the write.
        assert dma.dst_sem is self
        if self.interpret_params.detect_races:
          inc_vector_clock(dma.clock, dma.virtual_device_id)
        store(dma.dst_device_id,
              dma.dst_memory_space,
              dma.dst_buffer_id,
              dma.dst_transforms,
              dma.data,
              clock=copy_vector_clock(dma.clock),
              src_device_id=dma.id,
              source_info=dma.source_info)
        if self.interpret_params.detect_races:
          inc_vector_clock(dma.clock, dma.virtual_device_id)
        data_size = dma.data.itemsize * dma.data.size
        dma.dst_sem.signal(
            data_size, device_id=dma.dst_device_id, clock=dma.clock)

        dma.data = None
        dma.state = DmaState.COMPLETED


class DmaState(enum.Enum):
  STARTED = 0
  READ = 1
  COMPLETED = 2

@dataclasses.dataclass
class DMA:
  id: int

  src_device_id: int
  src_memory_space: int
  src_buffer_id: int
  src_transforms: tuple[Any, ...]
  dst_device_id: int
  dst_memory_space: int
  dst_buffer_id: int
  dst_transforms: tuple[Any, ...]
  src_sem: Semaphore
  dst_sem: Semaphore

  clock: VectorClock

  source_info: source_info_util.SourceInfo | None = None

  state: DmaState = DmaState.STARTED
  data: np.ndarray | None = None
  virtual_device_id: int | None = None
  lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)


@dataclasses.dataclass
class RaceDetectionState:
  num_devices: int

  # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)]
  reads: dict = dataclasses.field(
      default_factory=lambda: collections.defaultdict(list))

  # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)]
  writes: dict = dataclasses.field(
      default_factory=lambda: collections.defaultdict(list))

  lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)

  races_found: bool = False

def _is_empty_slice(slice_or_idx: slice | int):
  if isinstance(slice_or_idx, int) or (slice_or_idx == slice(None)):
    return False

  # NOTE: All slices here will have known size.
  start = int(slice_or_idx.start) if slice_or_idx.start is not None else 0
  stop = int(slice_or_idx.stop)
  return (start < stop)

def slices_overlap(slice_or_idx1: slice | int, slice_or_idx2: slice | int):
  if isinstance(slice_or_idx1, int):
    slice_or_idx1 = slice(slice_or_idx1, slice_or_idx1 + 1)
  if isinstance(slice_or_idx2, int):
    slice_or_idx2 = slice(slice_or_idx2, slice_or_idx2 + 1)

  if slice_or_idx1 == slice(None):
    return _is_empty_slice(slice_or_idx2)
  if slice_or_idx2 == slice(None):
    return _is_empty_slice(slice_or_idx1)

  # TODO(jburnim): Handle non-zero steps.
  assert (slice_or_idx1.step == 1) or (slice_or_idx1.step is None)
  assert (slice_or_idx2.step == 1) or (slice_or_idx2.step is None)

  # NOTE: We are only comparing slices with known stops (and sizes).
  # Do we need to handle zero-length slices?
  return ((slice_or_idx1.start <= slice_or_idx2.start < slice_or_idx1.stop)
          | (slice_or_idx2.start <= slice_or_idx1.start < slice_or_idx2.stop))

def ranges_overlap(range1: tuple[slice | int, ...],
                   range2: tuple[slice | int, ...]) -> bool:
  return all(slices_overlap(r1, r2) for r1, r2
             in itertools.zip_longest(range1, range2, fillvalue=slice(None)))
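
# Illustrative examples (not part of the implementation) for the overlap
# helpers above; missing trailing dimensions are treated as full slices:
#
#   ranges_overlap((slice(0, 4), 2), (slice(2, 8),))  # True
#   ranges_overlap((slice(0, 4),), (slice(4, 8),))    # False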

def check_read(device_id, clock, buffer_key, rnge, source_info=None):
  if source_info is not None:
    user_frame = source_info_util.summarize(source_info)
  else:
    user_frame = 'pallas_call'

  with races.lock:
    writes = races.writes[buffer_key]
    num_writes = len(writes)
    races.reads[buffer_key].append((device_id, clock, rnge, user_frame))

  for i in range(num_writes):
    write_device_id, write_clock, write_range, write_frame = writes[i]
    if ordered(write_clock, clock):
      continue
    if not ranges_overlap(rnge, write_range):
      continue
    # TODO(jburnim): When printing device IDs for reads/writes, distinguish
    # between real device IDs vs. DMA IDs.
    print('RACE DETECTED\n'
          f'  read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n'
          f'  write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}')
    with races.lock:
      races.races_found = True
    return

def check_write(device_id, clock, buffer_key, rnge, source_info=None):
  if source_info is not None:
    user_frame = source_info_util.summarize(source_info)
  else:
    user_frame = 'pallas_call'

  with races.lock:
    writes = races.writes[buffer_key]
    reads = races.reads[buffer_key]
    num_writes = len(writes)
    num_reads = len(reads)
    races.writes[buffer_key].append((device_id, clock, rnge, user_frame))

  # TODO(jburnim): For performance, we should also probably remove any
  # conflicting reads and writes that happened-before the current write.

  for i in range(num_writes):
    write_device_id, write_clock, write_range, write_frame = writes[i]
    if ordered(write_clock, clock):
      continue
    if not ranges_overlap(rnge, write_range):
      continue
    # TODO(jburnim): When printing device IDs for reads/writes, distinguish
    # between real device IDs vs. DMA IDs.
    print('RACE DETECTED\n'
          f'  write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n'
          f'  write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}')
    with races.lock:
      races.races_found = True
    break

  for i in range(num_reads):
    read_device_id, read_clock, read_range, read_frame = reads[i]
    if ordered(read_clock, clock):
      continue
    if not ranges_overlap(rnge, read_range):
      continue
    # TODO(jburnim): When printing device IDs for reads/writes, distinguish
    # between real device IDs vs. DMA IDs.
    print('RACE DETECTED\n'
          f'  write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n'
          f'  read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}')
    with races.lock:
      races.races_found = True
    return


@dataclasses.dataclass
class SharedMemory:
  interpret_params: TPUInterpretParams
  num_devices: int
  clocks: list[VectorClock]
  barrier: threading.Barrier

  # (memory_space, buffer_id, device_id) -> NumPy array
  # TODO(jburnim): Handle Megacore.
  mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field(
      default_factory=dict)

  # semaphore_id -> Semaphore
  sem: dict[int, Semaphore] = dataclasses.field(default_factory=dict)

  # (semaphore_id, device_id)
  #   -> list of DMAs that will signal the semaphore on the given device
  dmas_by_sem: dict[tuple[int, int], list[DMA]] = dataclasses.field(
      default_factory=lambda: collections.defaultdict(list))

  lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)

  # device_id -> next buffer ID
  next_buffer_id: dict[int, int] = dataclasses.field(
      default_factory=lambda: collections.defaultdict(lambda: 100))
  # device_id -> next semaphore ID
  next_semaphore_id: dict[int, int] = dataclasses.field(
      default_factory=lambda: collections.defaultdict(lambda: 2000))

  next_dma_id: int = 100


# TODO(jburnim): Do we want to support multiple instances of SharedMemory?
# Maybe for running multiple distinct interpreted computations in parallel?
_shared_memory : SharedMemory | None = None
_shared_memory_init_lock = threading.Lock()
races : RaceDetectionState | None = None

def _get_shared_memory() -> SharedMemory:
  assert _shared_memory is not None
  return _shared_memory

def _clear_shared_memory():
  global _shared_memory
  with _shared_memory_init_lock:
    _shared_memory = None

def _initialize_shared_memory(device_id, num_devices, *, interpret_params):
  global _shared_memory
  del device_id
  num_devices = int(num_devices)
  with _shared_memory_init_lock:
    if _shared_memory is None:
      _shared_memory = SharedMemory(
          interpret_params=interpret_params,
          num_devices=num_devices,
          clocks=[make_vector_clock(num_devices) for _ in range(num_devices)],
          barrier=threading.Barrier(num_devices))
  assert _shared_memory.num_devices == num_devices

  global races
  races = RaceDetectionState(num_devices=num_devices)

def _clean_up_shared_memory(device_id):
  device_id = int(device_id)
  shared_memory = _get_shared_memory()
  shared_memory.barrier.wait()
  if device_id == 0:
    _clear_shared_memory()

def _validate(device_id):
  device_id = int(device_id)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    for sem in shared_memory.sem.values():
      with sem.cv:
        if sem.counts[device_id] != 0:
          # TODO(jburnim): Make this raise an error, but in a way that doesn't
          # cause other devices to hang later in `_clean_up_shared_memory`.
          print(
              f'Semaphore {sem.id} has non-zero count for {device_id} at '
              f'kernel exit: {sem.counts[device_id]}')

def _allocate_buffer(device_id, memory_space, val):
  device_id = int(device_id)
  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
  val = np.array(val)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    buffer_id = shared_memory.next_buffer_id[device_id]
    shared_memory.next_buffer_id[device_id] = buffer_id + 1
    # TODO(jburnim): Add options for initializing memory (e.g., with NaNs,
    # with zeros, or with the buffer ID).
    shared_memory.mem[(memory_space, buffer_id, device_id)] = val

  # TODO(jburnim): Raise an error if buffer_id is too big for int16.
  return np.int16(buffer_id)

def _deallocate_buffer(device_id, memory_space, buffer_id):
  device_id = int(device_id)
  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
  buffer_id = int(buffer_id)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    # TODO(jburnim): Error if buffer doesn't exist?
    shared_memory.mem.pop((memory_space, buffer_id, device_id), None)

def _allocate_semaphores(device_id, shape):
  device_id = int(device_id)
  shape = tuple(map(int, shape))
  num_semaphores = math.prod(shape)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    semaphore_id = shared_memory.next_semaphore_id[device_id]
    shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores
    for i in range(semaphore_id, semaphore_id + num_semaphores):
      if i not in shared_memory.sem:
        shared_memory.sem[i] = Semaphore(i)

  # NOTE: For now, we use a relatively uncommon datatype (int16) for
  # semaphore (and buffer) IDs, so these values are more easily identifiable
  # in kernels.
  #
  # TODO(jburnim): Raise an error if any IDs are too big for int16.
  return np.int16(
      range(semaphore_id, semaphore_id + num_semaphores)
  ).reshape(shape)


TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = {
    v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)}
TPU_MEMORY_SPACE_NAMES = {
    i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)}

# Default to VMEM when no memory space is specified.
TPU_MEMORY_SPACE_IDXS[None] = (
    TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])

def get_barrier_semaphore(device_id, collective_id):
  del device_id
  collective_id = int(collective_id)

  # TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
  # with IDs for regular or DMA semaphores. (For example, store them in a
  # different table.)
  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    semaphore_id = collective_id
    if semaphore_id not in shared_memory.sem:
      shared_memory.sem[semaphore_id] = Semaphore()

  return np.int16(semaphore_id)

def _transform_slice_or_index(slice_or_idx):
  if isinstance(slice_or_idx, int):
    return slice_or_idx
  else:
    start = int(slice_or_idx.start)
    size = int(slice_or_idx.size)
    stride = int(slice_or_idx.stride)
    return slice(start, start + size * stride, stride)

def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
  ret = []
  i = 0
  j = 0
  while True:
    if i == len(slice_or_idx1):
      ret.extend(slice_or_idx2[j:])
      return tuple(ret)
    elif j == len(slice_or_idx2):
      ret.extend(slice_or_idx1[i:])
      return tuple(ret)
    elif isinstance(slice_or_idx1[i], int):
      ret.append(slice_or_idx1[i])
      i += 1
    elif isinstance(slice_or_idx2[j], int):
      ret.append(slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step)
      i += 1
      j += 1
    else:
      ret.append(slice(
          slice_or_idx1[i].start + slice_or_idx2[j].start * slice_or_idx1[i].step,
          slice_or_idx1[i].start + slice_or_idx2[j].stop * slice_or_idx1[i].step,
          slice_or_idx1[i].step * slice_or_idx2[j].step
      ))
      i += 1
      j += 1
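
# Illustrative examples (not part of the implementation): composing an outer
# block slice with an inner, block-relative slice or index.
#
#   _compose_slice_or_index((slice(8, 16, 1),), (slice(2, 5, 1),))
#   # == (slice(10, 13, 1),)
#   _compose_slice_or_index((slice(8, 16, 1),), (3,))
#   # == (11,)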

def _to_range(transforms) -> tuple[slice | int, ...]:
  ret = ()
  for transform in transforms:
    # For now, assume only NDIndexer transforms.
    ret = _compose_slice_or_index(
        ret, tuple(_transform_slice_or_index(i) for i in transform.indices))
  return ret

def get(device_id, memory_space, buffer_id, transforms, *,
        src_device_id=None, clock=None, source_info=None):
  device_id = int(device_id)
  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
  buffer_id = int(buffer_id)
  try:
    transforms = jax.tree.map(int, transforms)
  except:
    raise ValueError('Advanced indexers are not supported on TPU')

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    read_range = _to_range(transforms)
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
      if clock is None:
        clock = copy_vector_clock(shared_memory.clocks[device_id])
    buffer = shared_memory.mem[(memory_space, buffer_id, device_id)]
    ret = buffer[read_range].copy()
    if transforms:
      # TODO(jburnim): Instead of using NDIndexer, do the computation ourselves
      # with buffer.shape and read_range?
      expected_shape = transforms[-1].get_indexer_shape()
      if expected_shape != ret.shape[:len(expected_shape)]:
        raise ValueError(
            f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): '
            f'reading [{read_range}] but buffer has shape {buffer.shape} .')

  if shared_memory.interpret_params.detect_races:
    if src_device_id is None:
      src_device_id = device_id
    check_read(src_device_id, clock, (memory_space, buffer_id, device_id),
               read_range, source_info=source_info)

  return ret

def store(device_id, memory_space, buffer_id, transforms, val, *,
          src_device_id=None, clock=None, source_info=None):
  device_id = int(device_id)
  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
  buffer_id = int(buffer_id)
  try:
    transforms = jax.tree.map(int, transforms)
  except:
    raise ValueError('Advanced indexers are not supported on TPU')
  val = np.array(val)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
      if clock is None:
        clock = copy_vector_clock(shared_memory.clocks[device_id])

    buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
    assert buff.dtype == val.dtype  # TODO(jburnim): Catch this statically.
    write_range = _to_range(transforms)
    # TODO(jburnim): Better error message if this raises?
    in_bounds_shape = buff[write_range].shape
    if in_bounds_shape != val.shape:
      raise ValueError(
          f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): '
          f'writing [{write_range}] but buffer has shape {buff.shape} .')
    buff[write_range] = val

  if shared_memory.interpret_params.detect_races:
    if src_device_id is None:
      src_device_id = device_id
    check_write(src_device_id, clock, (memory_space, buffer_id, device_id),
                write_range, source_info=source_info)

def swap(device_id, memory_space, buffer_id, transforms, val, mask, *,
         source_info=None):
  device_id = int(device_id)
  memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
  buffer_id = int(buffer_id)
  try:
    transforms = jax.tree.map(int, transforms)
  except:
    raise ValueError('Advanced indexers are not supported on TPU')
  val = np.array(val)
  mask = np.array(mask) if mask is not None else None
  if mask is not None:
    assert mask.shape == val.shape

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
      clock = copy_vector_clock(shared_memory.clocks[device_id])
    buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
    assert buff.dtype == val.dtype  # TODO(jburnim): Catch this statically.
    read_write_range = _to_range(transforms)
    # TODO(jburnim): Better error message if this raises?
    raw_result = buff[read_write_range]
    in_bounds_shape = raw_result.shape
    if mask is None:
      if in_bounds_shape != val.shape:
        raise ValueError(
            f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): '
            f'swapping [{read_write_range}] but buffer has shape {buff.shape} .')
      buff[read_write_range] = val
      return raw_result.copy()

    in_bounds_mask = np.full(mask.shape, True)
    for i in range(len(in_bounds_shape)):
      in_bounds_mask[in_bounds_shape[i]:] = False
    if (~in_bounds_mask & mask).any():
      # TODO(jburnim): Include indices of out-of-bounds locations where mask
      # is True.
      raise ValueError(
          f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): '
          f'swapping [{read_write_range}] but buffer has shape {buff.shape} . ')

    in_bounds_idx = tuple(slice(i) for i in in_bounds_shape)
    result = val.copy()
    result[in_bounds_idx] = np.where(
        mask[in_bounds_idx], raw_result, val[in_bounds_idx])
    buff[read_write_range] = np.where(
        mask[in_bounds_idx], val[in_bounds_idx], raw_result)

  if shared_memory.interpret_params.detect_races:
    check_write(device_id, clock, (memory_space, buffer_id, device_id),
                read_write_range, source_info=source_info)
  return result

def execute_dma(dma):
  # TODO(jburnim): Eliminate duplicate code here and in Semaphore.wait.
  shared_memory = _get_shared_memory()
  with dma.lock:
    assert dma.state == DmaState.STARTED

    if dma.virtual_device_id is None:
      # See comment in Semaphore.wait .
      dma.virtual_device_id = np.random.randint(
          shared_memory.num_devices, NUM_VIRTUAL_DEVICES)

    # Do the read.
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(dma.clock, dma.virtual_device_id)
    dma.data = get(dma.src_device_id,
                   dma.src_memory_space,
                   dma.src_buffer_id,
                   dma.src_transforms,
                   clock=copy_vector_clock(dma.clock),
                   src_device_id=dma.id,
                   source_info=dma.source_info)
    data_size = dma.data.itemsize * dma.data.size

    # Signal the send semaphore.
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(dma.clock, dma.virtual_device_id)
    if dma.src_sem is not None:
      dma.src_sem.signal(
          data_size, device_id=dma.src_device_id, clock=dma.clock)
    dma.state = DmaState.READ

    # Do the write.
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(dma.clock, dma.virtual_device_id)
    store(dma.dst_device_id,
          dma.dst_memory_space,
          dma.dst_buffer_id,
          dma.dst_transforms,
          dma.data,
          clock=copy_vector_clock(dma.clock),
          src_device_id=dma.id,
          source_info=dma.source_info)

    # Signal the receive semaphore.
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(dma.clock, dma.virtual_device_id)
    if dma.dst_sem is not None:
      dma.dst_sem.signal(
          data_size, device_id=dma.dst_device_id, clock=dma.clock)

    dma.data = None
    dma.state = DmaState.COMPLETED

def print_memory(device_id):
  device_id = int(device_id)
  if device_id == 0:
    shared_memory = _get_shared_memory()
    with shared_memory.lock:
      print(shared_memory.mem)

def dma_start(device_id, src_memory_space, src_id, src_transforms,
              dst_memory_space, dst_id, dst_transforms,
              dst_sem_id, src_sem_id, dst_device_id,
              source_info=None):
  device_id = int(device_id)
  src_memory_space, src_id = int(src_memory_space), int(src_id)
  src_transforms = jax.tree.map(int, src_transforms)
  dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
  dst_transforms = jax.tree.map(int, dst_transforms)
  dst_sem_id = int(dst_sem_id)
  src_sem_id = int(src_sem_id) if src_sem_id is not None else None
  if dst_device_id is not None:
    dst_device_id = int(dst_device_id)
  else:
    dst_device_id = device_id

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    dst_sem = shared_memory.sem[dst_sem_id]
    src_sem = shared_memory.sem[src_sem_id] if src_sem_id is not None else None

    clock = None
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
      clock = copy_vector_clock(shared_memory.clocks[device_id])
    dma_id = shared_memory.next_dma_id
    shared_memory.next_dma_id += 1

    dma = DMA(
        dma_id,
        device_id, src_memory_space, src_id, src_transforms,
        dst_device_id, dst_memory_space, dst_id, dst_transforms,
        src_sem,
        dst_sem,
        clock=clock,
        source_info=source_info,
    )

    if shared_memory.interpret_params.dma_execution_mode == 'on_wait':
      shared_memory.dmas_by_sem[dst_sem_id].append(dma)
      if src_sem_id is not None:
        shared_memory.dmas_by_sem[src_sem_id].append(dma)
      return

  assert shared_memory.interpret_params.dma_execution_mode == 'eager'
  execute_dma(dma)

def dma_wait(device_id, sem_id, size):
  device_id = int(device_id)
  sem_id = int(sem_id)
  size = int(size)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
    sem = shared_memory.sem[sem_id]
  sem.wait(size, device_id, is_dma=True)

def semaphore_signal(device_id, sem_id, inc, target_device_id,
                     target_core_index):
  device_id = int(device_id)
  sem_id = int(sem_id)
  inc = int(inc)
  if target_device_id is None:
    target_device_id = device_id
  else:
    target_device_id = int(target_device_id)

  if target_core_index is not None:
    if int(target_core_index) != 0:
      raise NotImplementedError('semaphore_signal with target_core_index != 0')

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    clock = None
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
      clock = copy_vector_clock(shared_memory.clocks[device_id])
    sem = shared_memory.sem[sem_id]
  sem.signal(inc, target_device_id, clock)

def semaphore_wait(device_id, sem_id, value):
  device_id = int(device_id)
  sem_id = int(sem_id)
  value = int(value)

  shared_memory = _get_shared_memory()
  with shared_memory.lock:
    if shared_memory.interpret_params.detect_races:
      inc_vector_clock(shared_memory.clocks[device_id], device_id)
    sem = shared_memory.sem[sem_id]
  sem.wait(value, device_id)

def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
  for transform in transforms:
    if transform is None:
      continue
    shape = transform.transform_shape(shape)
    dtype = transform.transform_dtype(dtype)
  return shape, dtype

def _device_coords_to_logical_id(device_coords, axis_sizes):
  if not isinstance(device_coords, tuple):
    device_coords = (device_coords,)
  assert len(device_coords) == len(axis_sizes)
  sizes = list(axis_sizes.values())
  ret = 0
  for i in range(len(device_coords)):
    ret += device_coords[i] * math.prod(sizes[i+1:])
  return ret
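
# Illustrative example (not part of the implementation): mesh coordinates are
# flattened in row-major order, so with axis_sizes = {'x': 2, 'y': 4} the
# device at coordinates (1, 2) gets logical ID 1 * 4 + 2 == 6:
#
#   _device_coords_to_logical_id((1, 2), {'x': 2, 'y': 4})  # == 6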

def _device_id_to_logical(device_id, device_id_type, axis_sizes):
  if device_id is None:
    return None
  if device_id_type == mosaic_primitives.DeviceIdType.MESH:
    return _device_coords_to_logical_id(device_id, axis_sizes)
  elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL:
    return device_id
  else:
    raise ValueError(f'Unsupported device ID type: {device_id_type}')

@lu.cache
def _to_jaxpr(flat_fun, in_avals):
  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts)
  return new_jaxpr

def _is_any(memory_space):
  return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or
          (memory_space == pallas_core.MemorySpace.ANY))

def _is_float(dtype):
  return jnp.issubdtype(dtype, jnp.floating)

_SENTINEL = jnp.inf

@dataclasses.dataclass(frozen=True)
class Placeholder:
  """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`."""
  shape: tuple[int, ...]
  dtype: jnp.dtype

def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
  env = {}

  def read(var):
    if isinstance(var, jax_core.Literal):
      result = var.val
    else:
      result = env[var]
    if isinstance(result, Placeholder):
      result = jax.lax.full(result.shape, _SENTINEL, result.dtype)
    return result

  def write(var, value):
    if interpret_params.skip_floating_point_ops and _is_float(value.dtype):
      value = Placeholder(value.shape, value.dtype)
    env[var] = value

  jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)

  # Get the device ID.
  axis_sizes = jax_core.get_axis_env().axis_sizes
  device_id = _device_coords_to_logical_id(
      tuple(lax.axis_index(s) for s in axis_sizes.keys()),
      axis_sizes)
  # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
  # it for each sub-jaxpr.

  # TODO(jburnim): Clean up and finish this evaluation loop. For example:
  #  - Replace the big if-statement with a dictionary of rules.
  #  - Handle other higher-order primitives?
  #  - Megacore.
  _interpret = functools.partial(
      _interpret_jaxpr, compiler_params=compiler_params,
      interpret_params=interpret_params)
  for eqn in jaxpr.eqns:
    with source_info_util.user_context(
        eqn.source_info.traceback, name_stack=eqn.source_info.name_stack):
      prim = eqn.primitive
      # We defer reading the values for `eqn.invars` into each of the branches
      # of the if-elif-else statement below. This is because the else branch
      # may not need to do any reads if `interpret_params.skip_floating_point_ops`
      # is True. If this is the case, we want to avoid materializing the read
      # array into the jaxpr when this function is traced.
      deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars)

      if prim is primitives.load_p:
        (ref, transforms, mask, _) = jax.tree.unflatten(
            eqn.params['args_tree'], deferred_invals())
        if mask is not None:
          raise NotImplementedError('masked load_p')
        out = callback.io_callback(
            functools.partial(get, source_info=eqn.source_info),
            eqn.outvars[0].aval,
            device_id,
            TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
            ref,
            transforms,
            ordered=True)

      elif prim is primitives.swap_p:
        (ref, transforms, val, mask) = jax.tree.unflatten(
            eqn.params['args_tree'], deferred_invals())
        out = callback.io_callback(
            functools.partial(swap, source_info=eqn.source_info),
            eqn.outvars[0].aval,
            device_id,
            TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
            ref,
            transforms,
            val,
            mask,
            ordered=True)

      elif prim is mosaic_primitives.delay_p:
        out = []

      elif prim is lax.cond_p:
        def _make_branch(jaxpr):
          return lambda *args: _interpret(jaxpr, *args)
        invals = deferred_invals()
        out = lax.switch(
            invals[0],
            [_make_branch(branch_jaxpr.jaxpr)
             for branch_jaxpr in eqn.params['branches']],
            *invals[1:])

      elif prim is lax.scan_p:
        consts, init_carry, xs = split_list(
            deferred_invals(),
            [eqn.params['num_consts'], eqn.params['num_carry']],
        )
        def _scan_body(c, a):
          return split_list(
              _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a),
              [eqn.params['num_carry']])
        carry, out = lax.scan(_scan_body, init_carry, xs=xs,
                              length=eqn.params.get('length', None))
        out = carry + out

      elif prim is lax.while_p:
        cond_consts, body_consts, init_vals = split_list(
            deferred_invals(),
            [eqn.params['cond_nconsts'], eqn.params['body_nconsts']],
        )
        out = lax.while_loop(
            lambda args: _interpret(
                eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0],
            lambda args: _interpret(
                eqn.params['body_jaxpr'].jaxpr, *body_consts, *args),
            init_vals)

      elif prim is for_loop.for_p:
        raise NotImplementedError('for_p')

      elif prim is pjit.pjit_p:
        def f(*args, jaxpr):
          return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args)
        invals = deferred_invals()
        in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
        new_jaxpr = _to_jaxpr(
            lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']),
                         debug_info=eqn.params['jaxpr'].jaxpr.debug_info),
            in_avals)
        out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))

      elif prim is primitives.run_scoped_p:
        # Allocate a buffer or semaphore for each element of
        # eqn.params['jaxpr'].invars .
        allocs = []
        for v in eqn.params['jaxpr'].invars:
          if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
            allocs.append(callback.io_callback(
                _allocate_semaphores,
                jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
                device_id,
                v.aval.shape,
                ordered=True))
          else:
            allocs.append(callback.io_callback(
                _allocate_buffer,
                jax.ShapeDtypeStruct((), jnp.int16),
                device_id,
                TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
                _uninitialized_value(
                    v.aval.shape, v.aval.dtype, interpret_params),
                ordered=True))

        out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)

        for a in allocs:
          if isinstance(a, tuple):
            callback.io_callback(
                _deallocate_buffer,
                None,
                device_id,
                TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
                a,
                ordered=True)
          else:
            # TODO(jburnim): De-allocate semaphores.
            # callback.io_callback(
            #     _deallocate_semaphores,
            #     None,
            #     device_id,
            #     a,
            #     ordered=True)
            pass

      elif prim is state_primitives.get_p:
        invals = deferred_invals()
        out = callback.io_callback(
            functools.partial(get, source_info=eqn.source_info),
            eqn.outvars[0].aval,
            device_id,
            TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
            invals[0],
            jax.tree.unflatten(eqn.params['tree'], invals[1:]),
            ordered=True)

      elif prim is state_primitives.swap_p:
        invals = deferred_invals()
        out = callback.io_callback(
            functools.partial(swap, source_info=eqn.source_info),
            eqn.outvars[0].aval,
            device_id,
            TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
            invals[0],
            jax.tree.unflatten(eqn.params['tree'], invals[2:]),
            invals[1],
            None,
            ordered=True)

      elif prim is mosaic_primitives.dma_start_p:
        (
            src,
            src_transforms,
            dst,
            dst_transforms,
            dst_sem,
            dst_sem_transforms,
            src_sem,
            src_sem_transforms,
            target_device_id,
        ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
        target_device_id = _device_id_to_logical(
            target_device_id, eqn.params['device_id_type'], axis_sizes)
        (orig_src_ref, _, orig_dst_ref, *_
         ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
        callback.io_callback(
            functools.partial(dma_start, source_info=eqn.source_info),
            (),
            device_id,
            TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)],
            src, src_transforms,
            TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)],
            dst, dst_transforms,
            state_discharge.transform_array(dst_sem, dst_sem_transforms),
            state_discharge.transform_array(src_sem, src_sem_transforms),
            target_device_id,
            ordered=True)
        out = []

      elif prim is mosaic_primitives.dma_wait_p:
        (
            src,
            src_transforms,
            dst,
            dst_transforms,
            dst_sem,
            dst_sem_transforms,
            src_sem,
            src_sem_transforms,
            target_device_id,
        ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
        read_shape, read_dtype = _compute_transformed_shape_and_dtype(
            eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
        callback.io_callback(
            dma_wait,
            (),
            device_id,
            state_discharge.transform_array(dst_sem, dst_sem_transforms),
            math.prod(read_shape) * read_dtype.itemsize,
            ordered=True)
        out = []

      elif prim is mosaic_primitives.get_barrier_semaphore_p:
        out = callback.io_callback(
            get_barrier_semaphore,
            jax.ShapeDtypeStruct((), jnp.int16),
            device_id,
            compiler_params['mosaic']['collective_id'],
            ordered=True)

      elif prim is mosaic_primitives.semaphore_signal_p:
        sem, sem_transforms, inc, target_device_id, core_index = (
            jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
        target_device_id = _device_id_to_logical(
            target_device_id, eqn.params['device_id_type'], axis_sizes)
        callback.io_callback(
            semaphore_signal,
            (),
            device_id,
            state_discharge.transform_array(sem, sem_transforms),
            inc,
            target_device_id,
            core_index,
            ordered=True)
        out = []

      elif prim is mosaic_primitives.semaphore_wait_p:
        sem, sem_transforms, value = (
            jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
        callback.io_callback(
            semaphore_wait,
            (),
            device_id,
            state_discharge.transform_array(sem, sem_transforms),
            value,
            ordered=True)
        out = []

      elif prim is primitives.atomic_rmw_p:
        raise NotImplementedError('atomic_rmw_p')

      elif prim is primitives.atomic_cas_p:
        raise NotImplementedError('atomic_cas_p')

      else:
        if interpret_params.skip_floating_point_ops and all(
            _is_float(ovar.aval.dtype) for ovar in eqn.outvars
        ):
          # Skip `prim.bind` since `prim` only produces floating-point values.
          # It is safe to populate `out` with avals since mapping `write` over
          # `out` below only relies on the shape and dtype (for writing
          # `Placeholder`s).
          out = [ovar.aval for ovar in eqn.outvars]
          if not prim.multiple_results:
            out = out[0]
        else:
          subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
          out = prim.bind(*subfuns, *deferred_invals(), **bind_params)

      out = out if prim.multiple_results else [out]
      jax.util.safe_map(write, eqn.outvars, out)

  return jax.util.safe_map(read, jaxpr.outvars)

def _initialize_output_vals(
    block_mappings_output: Iterable[BlockMapping],
    input_args, input_output_aliases,
    interpret_params: TPUInterpretParams,
) -> Sequence[jax.Array]:
  oi_map = {v: k for k, v in input_output_aliases}
  output_vals = []
  for i, bm in enumerate(block_mappings_output):
    if i in oi_map:
      output_vals.append(input_args[oi_map[i]])
    else:
      output_vals.append(_uninitialized_value(
          bm.array_shape_dtype.shape,
          bm.array_shape_dtype.dtype,
          interpret_params))
  return output_vals

def _compute_start_indices(block_mapping, loop_idx, *args):
  block_indices = (
      jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args))
  if isinstance(block_mapping.indexing_mode, pallas_core.Blocked):
    ret = tuple(i if b is pallas_core.mapped else b * i
                for b, i in zip(block_mapping.block_shape, block_indices))
  elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked):
    ret = block_indices
  else:
    raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}")
  return ret

def _get_next_indices(grid, indices):
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = dim_size == i
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))
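
# Illustrative example (not part of the implementation): grid indices advance
# innermost-first, with carries propagating outward (values shown as plain
# ints, although the computation above is traced with jnp):
#
#   _get_next_indices((2, 3), (0, 2))  # -> (1, 0)
#   _get_next_indices((2, 3), (1, 2))  # -> (0, 0), i.e. wraps around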

def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
  start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx)
  output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
                                                            dtype=np.bool_)])
  return lax.squeeze(output, squeeze_dims)

def _uninitialized_value(shape, dtype, interpret_params):
  if interpret_params.uninitialized_memory == 'nan':
    if jnp.issubdtype(dtype, jnp.floating):
      return jnp.full(shape, jnp.nan, dtype)
    elif jnp.issubdtype(dtype, jnp.integer):
      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
    elif jnp.issubdtype(dtype, jnp.bool):
      return jnp.full(shape, False, dtype)
  if interpret_params.uninitialized_memory == 'zero':
    return jnp.full(shape, 0, dtype)
  raise NotImplementedError(
      interpret_params.uninitialized_memory + ' + ' + str(dtype))

def _pad_to_block_dimension(value, block_shape, interpret_params):
  """Pads values so the shape evenly divides into block dimensions.

  For example, if values has a shape of (33, 2, 5) with a block_shape of
  (32, 2, 4), this function will pad the value of shape to (64, 2, 8).

  Args:
    value: Array to be padded.
    block_shape: Block shapes to use for padding. If None, no padding will
      be performed.
    interpret_params: Interpret-mode parameters; determines the padding value
      used for the uninitialized region.

  Returns:
    A padded array.
  """
  padded_shape = tuple(
      ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
  )
  if padded_shape != value.shape:
    pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
    pad_value = _uninitialized_value((), value.dtype, interpret_params)
    value = jnp.pad(value, pad_width, constant_values=pad_value)
  return value

def get_interpret_effects():
  return {callback._OrderedIOEffect}

def interpret_pallas_call(
    *args,
    jaxpr: jax_core.Jaxpr,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: GridMapping,
    mesh: pallas_core.Mesh | None,
    compiler_params: Any,
    cost_estimate: CostEstimate,
    out_avals: tuple[jax_core.AbstractValue, ...],
    interpret_params: TPUInterpretParams,
):
  del debug, mesh, cost_estimate, out_avals

  # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
  dynamic_grid_args, scalars, input_args = split_list(
      args,
      [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands],
  )
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
  assert next(dynamic_grid_args_iter, None) is None

  axis_sizes = jax_core.get_axis_env().axis_sizes
  num_devices = functools.reduce(
      jnp.multiply, axis_sizes.values(), jnp.int32(1))
  device_id = _device_coords_to_logical_id(
      tuple(lax.axis_index(s) for s in axis_sizes.keys()),
      axis_sizes)
  callback.io_callback(
      functools.partial(
          _initialize_shared_memory, interpret_params=interpret_params),
      (),
      device_id,
      num_devices,
      ordered=True)

  # Pad input arguments.
  is_indexing_dim = [
      tuple(b is pallas_core.mapped for b in bm.block_shape)
      for bm in grid_mapping.block_mappings
  ]
  block_shapes = [
      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
  ]
  num_inputs = grid_mapping.num_inputs
  input_args = [
      _pad_to_block_dimension(a, bs, interpret_params)
      for a, bs in zip(input_args, block_shapes[:num_inputs])
  ]

  # Allocate buffers in HBM for outputs.
  output_buffer_ids = []
  output_buffer_shapes = []
  output_vals = _initialize_output_vals(
      grid_mapping.block_mappings_output,
      scalars + input_args,
      input_output_aliases,
      interpret_params)
  num_outputs = grid_mapping.num_outputs
  output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
  for out_val, bs in zip(output_vals, output_block_shapes):
    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
    output_buffer_shapes.append(padded_val.shape)
    output_buffer_ids.append(callback.io_callback(
        _allocate_buffer,
        jax.ShapeDtypeStruct((), jnp.int16),
        device_id,
        TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
        padded_val,
        ordered=True))
  # Allocate buffers for all kernel arguments (e.g., scalars, inputs,
  # outputs, scratch).
  io_alias_map = dict(input_output_aliases)
  oi_alias_map = {v: k for k, v in input_output_aliases}
  kernel_buffer_ids = []
  for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
    kernel_buffer_ids.append(callback.io_callback(
        _allocate_buffer,
        jax.ShapeDtypeStruct((), jnp.int16),
        device_id,
        TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
        val,
        ordered=True))
  for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
    output_idx = i - grid_mapping.num_inputs
    is_input = i < grid_mapping.num_inputs
    is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
    if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
      kernel_buffer_ids.append(callback.io_callback(
          _allocate_semaphores,
          jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
          device_id,
          var.aval.shape,
          ordered=True))
    elif is_output and _is_any(var.aval.memory_space):
      # Use the already-allocated HBM output buffer.
      #
      # TODO(jburnim): For kernel args in HBM, check that block shape is the
      # same as for the corresponding pallas_call input, and that the index_map
      # is trivial.
      kernel_buffer_ids.append(output_buffer_ids[output_idx])
    elif is_output and (output_idx in oi_alias_map):
      # Use the already-allocated (non-HBM) input buffer.
      kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]])
    elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space):
      # Use the already-allocated HBM output buffer.
      kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
    else:
      # TODO(jburnim): For kernel args in HBM, check that block shape is the
      # same as for the corresponding pallas_call input, and that the index_map
      # is trivial.
      kernel_buffer_ids.append(callback.io_callback(
          _allocate_buffer,
          jax.ShapeDtypeStruct((), jnp.int16),
          device_id,
          TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
          _uninitialized_value(
              var.aval.shape, var.aval.dtype, interpret_params),
          ordered=True))

  _, input_ids, kernel_output_ids, _ = split_list(
      kernel_buffer_ids,
      [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
  input_vars, output_vars = split_list(
      jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs])

  # For kernel inputs that are in HBM, we populate the buffer once before
  # any kernel invocations.
  for buffer_id, var, val in zip(input_ids, input_vars, input_args):
    if not _is_any(var.aval.memory_space):
      continue
    if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype):
      # TODO(jburnim): Also check that the index_map is trivial.
      raise ValueError()
    callback.io_callback(
        store,
        (),
        device_id,
        TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
        buffer_id,
        (),
        val,
        ordered=True)

  if grid:
    num_iterations = functools.reduce(jnp.multiply, grid)  # type: ignore[arg-type]
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1

  def body(carry):
    # The loop carry: (i, loop_idx) --
    #  - i:int32 is the iteration index
    #  - loop_idx: tuple[int32] are the program ids for each grid axis
    i, loop_idx = carry

    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )

    with pallas_core.grid_env(local_grid_env):
      # Copy slices of the input to the kernel buffers.
      #
      # TODO(jburnim): Only copy slices when the index mapping has changed?
      start_indices = [_compute_start_indices(bm, loop_idx, *scalars)
                       for bm in grid_mapping.block_mappings]
      for j, var in enumerate(input_vars):
        if _is_any(var.aval.memory_space):
          continue
        sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j],
                                          input_args[j], is_indexing_dim[j])
        assert(sliced_val.shape == var.aval.shape)
        callback.io_callback(
            # TODO(jburnim): Pass source_info from the pallas_call, in case this
            # store is involved in a data race.
            store,
            (),
            device_id,
            TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
            input_ids[j],
            (),
            sliced_val,
            ordered=True)

      # Invoke the kernel.
      _interpret_jaxpr(jaxpr, *kernel_buffer_ids,
                       compiler_params=compiler_params,
                       interpret_params=interpret_params)

      # Copy from the kernel buffers to slices of the output in HBM.
      #
      # TODO(jburnim): Only copy if the index mapping will change in the
      # next iteration (or if this is the last iteration)?
      for j, var in enumerate(output_vars):
        if _is_any(var.aval.memory_space):
          continue
        kernel_output_val = callback.io_callback(
            # TODO(jburnim): Pass source_info from the pallas_call, in case this
            # get is involved in a data race.
            get,
            var.aval,
            device_id,
            TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
            kernel_output_ids[j],
            (),
            ordered=True)
        transform = indexing.NDIndexer(
            indices=tuple(indexing.ds(st, sz) if not iid else st
                          for st, sz, iid in zip(start_indices[num_inputs + j],
                                                 block_shapes[num_inputs + j],
                                                 is_indexing_dim[num_inputs + j])),
            shape=output_vals[j].shape,
            int_indexer_shape=())
        callback.io_callback(
            # TODO(jburnim): Pass source_info from the pallas_call, in case this
            # store is involved in a data race.
            store,
            (),
            device_id,
            TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
            output_buffer_ids[j],
            (transform,),
            kernel_output_val,
            ordered=True)

      return i + 1, _get_next_indices(grid, loop_idx)

  # TODO(jburnim): Handle parallel grid dimensions + megacore.
  _ = lax.while_loop(
      lambda carry: carry[0] < num_iterations,
      body,
      (jnp.int32(0), (jnp.int32(0),) * len(grid))
  )

  # Read the output from the allocated output buffers.
  ret = [
      callback.io_callback(
          # TODO(jburnim): Pass source_info from the pallas_call, in case this
          # get is involved in a data race.
          get,
          val,
          device_id,
          TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
          output_buffer_id,
          (indexing.NDIndexer.from_indices_shape(
              tuple(indexing.ds(0, s) for s in val.shape),
              output_buffer_shape),),
          ordered=True)
      for val, output_buffer_id, output_buffer_shape in zip(
          output_vals, output_buffer_ids, output_buffer_shapes)
  ]

  callback.io_callback(
      _validate,
      (),
      device_id,
      ordered=True)

  # For now, when we're done with a pallas_call, we delete the shared memory.
  # We use a barrier to ensure that all devices are done running the kernel.
  #
  # TODO(jburnim): Get rid of this barrier. And figure out how this should
  # work if we want to invoke successive pallas_calls that use the same
  # shared memory.
  callback.io_callback(
      _clean_up_shared_memory,
      (),
      device_id,
      ordered=True)

  return ret