# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for calling pallas functions from JAX."""

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
import enum
from functools import partial, reduce
import types
from typing import Any, Literal

import jax
from jax import lax
from jax._src import ad_util
from jax._src import api_util
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
    safe_map,
    safe_zip,
    split_list,
    tuple_insert,
    unzip2,
    weakref_lru_cache,
)
import jax.numpy as jnp

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
ScratchShapeTree = pallas_core.ScratchShapeTree
CostEstimate = pallas_core.CostEstimate

# See the docstring for GridMapping for the calling convention
pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True


def _pallas_call_impl(*args, **params):
  # Call the lowering path
  @partial(jax.jit, inline=True)
  def _jit_run(*args):
    return pallas_call_p.bind(*args, **params)
  return _jit_run(*args)


pallas_call_p.def_impl(_pallas_call_impl)


def _pallas_call_abstract_eval(
    *avals,
    out_avals: tuple[jax_core.AbstractValue, ...],
    interpret,
    backend,
    **params
):
  del avals

  if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
    # Report effects that will be introduced when running/lowering
    # mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call .
    effs = mosaic_tpu_interpret.get_interpret_effects()
  else:
    effs = jax_core.no_effects

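  # Illustrative note (hypothetical shapes, not taken from any caller): an
  # entry of out_avals that is a ShapedArrayWithMemorySpace with shape (8, 128)
  # and dtype float32 is returned below as a plain ShapedArray((8, 128),
  # float32); ordinary ShapedArrays pass through unchanged.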
  # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
  return [
      jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
      if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
      else a
      for a in out_avals
  ], effs


pallas_call_p.def_effectful_abstract_eval(_pallas_call_abstract_eval)


def _pallas_call_jvp_rule(
    primals,
    tangents,
    *,
    jaxpr: jax_core.Jaxpr,
    name_and_src_info,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping,
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    backend: _Backend | None,
):
  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
  if grid_mapping.num_index_operands:
    raise NotImplementedError
  if input_output_aliases:
    raise NotImplementedError("JVP with aliasing not supported.")
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
  # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
  # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
  # `Ref`s that are read from followed by output `Ref`s that are written to.
  # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new
  # jaxpr that has tangents following primals. In order for this jaxpr to be
  # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
  # the jaxpr's invars.
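  # Illustrative sketch (hypothetical names): with one primal input ref `x_ref`
  # and one output ref `o_ref`, `jvp_jaxpr` hands back invars ordered as
  #   (x_ref, o_ref, dx_ref, do_ref)
  # and the reshuffle below produces the pallas_call convention
  #   (x_ref, dx_ref, o_ref, do_ref),
  # i.e. all inputs first, then all outputs.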
  primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
      jvp_jaxpr.invars, [len(primals), grid_mapping.num_outputs, len(tangents)]
  )
  invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
  effs = []
  for eff in jvp_jaxpr.effects:
    if isinstance(eff, effects.JaxprInputEffect):
      eff = eff.replace(
          input_index=invars.index(jvp_jaxpr.invars[eff.input_index])
      )
    effs.append(eff)
  jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
  if debug:
    print(f"\nThe jaxpr for the jvp of pallas_call {name_and_src_info}:")
    print(jvp_jaxpr)
  in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
  jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
  jvp_grid_mapping = grid_mapping.replace(
      block_mappings=jvp_bms,
      num_inputs=grid_mapping.num_inputs * 2,
      num_outputs=grid_mapping.num_outputs * 2,
  )
  if cost_estimate is not None:
    jvp_cost_estimate = CostEstimate(
        flops=2 * cost_estimate.flops,
        bytes_accessed=2 * cost_estimate.bytes_accessed,
        transcendentals=2 * cost_estimate.transcendentals,
    )
  else:
    jvp_cost_estimate = None
  out_flat = pallas_call_p.bind(
      *primals,
      *tangents,
      jaxpr=jvp_jaxpr,
      name_and_src_info=name_and_src_info.replace(
          name=f"{name_and_src_info.name}_jvp"
      ),
      grid_mapping=jvp_grid_mapping,
      interpret=interpret,
      debug=debug,
      input_output_aliases=(),
      compiler_params=compiler_params,
      cost_estimate=jvp_cost_estimate,
      out_avals=(*out_avals, *out_avals),
      backend=backend,
  )
  out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
  return out_primals, out_tangents


ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule


def _batch_block_mapping(
    grid_mapping: GridMapping,
    axis_size: int,
    for_ragged: bool,
    aval: jax_core.ShapedArray,
    dim: int | batching.NotMapped,
    block_mapping: BlockMapping,
    ragged_axis_values,
) -> BlockMapping:
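  # Illustrative sketch (assumed caller, not taken from this file): batching an
  # operand along a new leading axis (dim=0) turns a
  # BlockSpec((128, 128), lambda i, j: (i, j)) into an index map equivalent to
  # lambda b, i, j: (b, i, j) and a block shape of
  # (pallas_core.mapped, 128, 128), so each batch element addresses its own
  # slab of the stacked operand.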
  def _block_map_function(new_idx, *args):
    if for_ragged:
      drop_last_args = args[:-1]
    else:
      drop_last_args = args

    indices = jax_core.eval_jaxpr(
        block_mapping.index_map_jaxpr.jaxpr,
        block_mapping.index_map_jaxpr.consts,
        *drop_last_args,
    )
    if dim is not batching.not_mapped:
      if isinstance(dim, batching.RaggedAxis):
        assert for_ragged, "Ragged axis not supported for non-ragged batching."
        stacked_axis = dim.stacked_axis
        indices.insert(stacked_axis, new_idx)
      else:
        indices.insert(dim, new_idx)
    return tuple(indices)
  idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]

  if for_ragged:
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      _, _, _, lengths_aval = ragged_axis_values
      idx_avals = [*idx_avals, lengths_aval]
    else:
      i32_aval_memref = pallas_core.AbstractMemoryRef(
          jax_core.ShapedArray(([axis_size]), jnp.int32),
          pallas_core.MemorySpace.INDEX,
      )
      idx_avals = [*idx_avals, i32_aval_memref]

  with grid_mapping.trace_env():
    block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(_block_map_function,
                     debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
        idx_avals)
  shape = block_mapping.block_shape
  if dim is batching.not_mapped:
    new_block_shape = shape
    new_array_shape_dtype = block_mapping.array_shape_dtype
  else:
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      new_block_shape = shape
      stacked_axis = dim.stacked_axis
      new_block_shape = tuple_insert(
          new_block_shape, stacked_axis, pallas_core.mapped
      )
    else:
      new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)

    array_shape = block_mapping.array_shape_dtype.shape
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      stacked_axis = dim.stacked_axis
      array_shape = tuple_insert(array_shape, stacked_axis, axis_size)
    else:
      array_shape = tuple_insert(array_shape, dim, axis_size)

    new_array_shape_dtype = jax.ShapeDtypeStruct(
        array_shape, block_mapping.array_shape_dtype.dtype
    )

  jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
  return block_mapping.replace(block_shape=new_block_shape,
                               array_shape_dtype=new_array_shape_dtype,
                               index_map_jaxpr=jaxpr)


def _broadcast_input_output_aliases(
    args: Sequence[jax.Array],
    dims: Sequence[int | batching.NotMapped],
    *,
    input_output_aliases: tuple[tuple[int, int], ...],
    axis_size: int,
) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]:
  """Broadcast input/output operands.

  When we have input/output aliasing, since the output will be mapped, we need
  to make sure to broadcast the input across that dimension if it is not
  mapped. If the input is mapped, but on a different axis, we transpose the
  input to match the output.
  """
  args_ = list(args)
  dims_ = list(dims)
  for input_index, _ in input_output_aliases:
    dim = dims_[input_index]
    dims_[input_index] = 0
    if isinstance(dim, batching.RaggedAxis):
      stacked_axis = dim.stacked_axis
      if stacked_axis != 0:
        raise NotImplementedError("Ragged aliasing on non 0 dim NYI")
      return tuple(args_), tuple(dims_)

    if dim is batching.not_mapped:
      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
    elif dim != 0:
      # TODO(cjfj): Change output batching axis instead?
      args_[input_index] = jnp.moveaxis(args[input_index], dim, 0)

  return tuple(args_), tuple(dims_)


def _batch_with_explicit_loop(
    args: Sequence[jax.Array],
    dims: Sequence[int | batching.NotMapped],
    *,
    jaxpr: jax_core.Jaxpr,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    backend: _Backend | None,
):
  """Batch the pallas_call by calling it in loop over the batch size.

  This function provides a fallback implementation of batching a pallas_call
  for the cases in which adding a batch dimension to the pallas grid is not
  supported. This is currently the case when the batched dimension corresponds
  to a dynamic axis or a scalar prefetch argument.

  This implementation builds an HLO loop that dynamic_slices the inputs
  according to the current iteration index and dynamic_updates an (initially
  empty) output allocation.
  """
  if not dims:
    raise NotImplementedError("vmapping pallas_call with no arguments.")

  (axis_size,) = {
      arg.shape[dim]
      for arg, dim in zip(args, dims)
      if dim is not batching.not_mapped
  }

  args, dims = _broadcast_input_output_aliases(
      args,
      dims,
      input_output_aliases=input_output_aliases,
      axis_size=axis_size,
  )

  # The output arrays are completely overwritten, so we can just initialize
  # empty arrays.
  initial_state = [
      jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size),
                dtype=bm.array_shape_dtype.dtype)
      for bm in grid_mapping.block_mappings_output
  ]

  def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
    batch_args = []

    for arg, dim in zip(args, dims):
      # If the argument is mapped, extract a slice of size 1 in the mapped
      # dimension at the current index.
      if dim is batching.not_mapped:
        batch_args.append(arg)
      else:
        batch_args.append(
            jnp.squeeze(
                jax.lax.dynamic_slice_in_dim(
                    operand=arg,
                    start_index=batch_index,
                    slice_size=1,
                    axis=dim,
                ),
                axis=dim,
            )
        )
    batch_out = pallas_call_p.bind(
        *batch_args,
        jaxpr=jaxpr,
        name_and_src_info=name_and_src_info,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
        cost_estimate=cost_estimate,
        out_avals=out_avals,
        backend=backend,
    )
    for i, batch_out_array in enumerate(batch_out):
      state[i] = jax.lax.dynamic_update_index_in_dim(
          state[i],
          batch_out_array,
          batch_index,
          axis=0,
      )

    return state

  result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False)

  return result, (0,) * len(result)


def _pallas_call_batching_rule(
    args,
    dims,
    *,
    jaxpr: jax_core.Jaxpr,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    backend: _Backend | None,
):
  def _maybe_squeeze_out_bdim(
      x: jax.Array, bdim: int | batching.NotMapped
  ) -> jax.Array:
    if bdim is batching.not_mapped:
      return x
    return jnp.squeeze(x, axis=bdim)

  def get_size(i, x, d):
    if not isinstance(d, batching.RaggedAxis):
      return x.shape[d]
    return x.aval.shape[d.stacked_axis]

  (axis_size,) = {
      get_size(i=i, x=x, d=d)
      for i, (x, d) in enumerate(zip(args, dims))
      if d is not batching.not_mapped
  }
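  # Note: unpacking the one-element set above doubles as a consistency check -
  # e.g. mapped args whose batch sizes are {4, 4} give axis_size == 4, while
  # {4, 8} would fail to unpack into a single value and raise an error.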
  if axis_size == 1:
    # Why are we even vmapping?
    args = map(_maybe_squeeze_out_bdim, args, dims)
    out = pallas_call_p.bind(
        *args,
        jaxpr=jaxpr,
        name_and_src_info=name_and_src_info,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
        cost_estimate=cost_estimate,
        out_avals=out_avals,
        backend=backend,
    )
    return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)

  # The first num_dynamic_grid_bounds arguments are size-1 arrays that store
  # the size of the dynamic bounds.
  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  dynamic_grid_dims, dims = split_list(
      dims, [grid_mapping.num_dynamic_grid_bounds]
  )
  if all(
      bdim is batching.not_mapped or arg.shape[bdim] == 1
      for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
  ):
    dynamic_grid_args = safe_map(
        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims
    )
  elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
    # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid
    # bounds.
    return _batch_with_explicit_loop(
        args=dynamic_grid_args + args,
        dims=dynamic_grid_dims + dims,
        jaxpr=jaxpr,
        name_and_src_info=name_and_src_info,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
        cost_estimate=cost_estimate,
        out_avals=out_avals,
        backend=backend,
    )
  else:
    pass  # No dynamic grid dimensions
  del dynamic_grid_dims
  if grid_mapping.num_index_operands:
    scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
    scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
    # Ordinarily, adding support for scalar prefetch in vmap would involve
    # modifying the block specs in a nontrivial way. However, if we are only
    # vmapping over 1-sized dimensions, we can just get rid of the dimensions
    # and pretend we were never vmapped over them at all.
    if all(
        bdim is batching.not_mapped or arg.shape[bdim] == 1
        for arg, bdim in zip(scalar_args, scalar_bdims)
    ):
      scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
      scalar_bdims = [batching.not_mapped] * len(scalar_args)
      args = (*scalar_args, *args)
      dims = (*scalar_bdims, *bdims)
    else:
      # TODO(amagni,sharadmv,apaszke): enable efficient batching over
      # prefetched scalar args.
      return _batch_with_explicit_loop(
          args=scalar_args + args,
          dims=scalar_bdims + bdims,
          jaxpr=jaxpr,
          name_and_src_info=name_and_src_info,
          grid_mapping=grid_mapping,
          input_output_aliases=input_output_aliases,
          debug=debug,
          interpret=interpret,
          compiler_params=compiler_params,
          cost_estimate=cost_estimate,
          out_avals=out_avals,
          backend=backend,
      )

  if not dims:
    raise NotImplementedError("vmapping pallas_call with no arguments.")
  block_mappings = grid_mapping.block_mappings
  avals = [v.aval for v in jaxpr.invars]
  # How should we pick output dimensions? This actually matters because XLA
  # can't optimize our pallas kernels, and this layout impacts performance.
  # `vmap` doesn't really offer a way of inferring good output dimensions, so
  # for now we just use 0.
  # TODO(sharadmv): explore inferring better output dimensions via a heuristic
  # TODO(sharadmv): explore a long term solution to output dim inference

  args, dims = _broadcast_input_output_aliases(
      args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
  )

  # Each dim either has data about its ragged axis, or None
  ragged_axis_values = []
  for d in dims:
    if isinstance(d, batching.RaggedAxis):
      stacked_axis, ragged_axis_dim, ragged_axis_length = (
          batching._ragged_axis_parts(d)
      )
      aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
      if isinstance(aval, jax_core.DShapedArray):
        aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
      lengths_aval = pallas_core.AbstractMemoryRef(
          aval,
          pallas_core.MemorySpace.INDEX,
      )
      # TODO(mvoz): Give this its own type
      ragged_axis_values.append(
          (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval)
      )
    else:
      ragged_axis_values.append(None)  # type: ignore[arg-type]

  all_dims = list(dims) + [0] * grid_mapping.num_outputs
  ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs  # type: ignore[list-item]

  num_index_operands = grid_mapping.num_index_operands
  num_scratch_operands = grid_mapping.num_scratch_operands

  # Only add a batch dimension for the avals that actually have a grid mapping.
  # This excludes scalar prefetch inputs (the first in the list) and scratch
  # operands (the last in the list).
  avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]

  batched_block_mappings = map(
      partial(
          _batch_block_mapping,
          grid_mapping,
          axis_size,
          any(ragged_axis_values),
      ),
      avals_to_batch,
      all_dims[num_index_operands:],
      block_mappings,
      ragged_axis_values[num_index_operands:],
  )

  index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
      grid_mapping.index_map_avals)
  assert not index_map_tree_kwargs
  batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args

  lengths_aval = None  # type: ignore[assignment]

  # Check all the ragged axis values, ensure their raggedness pattern
  # is identical (consider moving this check up!)
  for rav in ragged_axis_values:
    if rav is not None:
      if lengths_aval is None:
        lengths_aval = rav[3]
      else:
        assert lengths_aval == rav[3], "NYI - different lengths in ragged batch"

  if lengths_aval:
    batched_index_map_args = batched_index_map_args + (lengths_aval,)
    num_index_operands += 1

  batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
      (batched_index_map_args, {}))

  batched_grid_mapping = grid_mapping.replace(
      grid=(axis_size, *grid_mapping.grid),
      block_mappings=tuple(batched_block_mappings),
      index_map_avals=tuple(batched_index_map_avals),
      index_map_tree=batched_index_map_tree,
      num_index_operands=num_index_operands,
      vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims),
  )

  if cost_estimate is not None:
    batched_cost_estimate = CostEstimate(
        flops=cost_estimate.flops * axis_size,
        bytes_accessed=cost_estimate.bytes_accessed * axis_size,
        transcendentals=cost_estimate.transcendentals * axis_size,
    )
  else:
    batched_cost_estimate = None

  # Start the ragged handling code
  # Here, we:
  # - Rewrite the indexer to save memory (skip indices outside the ragged bounds)
  # - Rewrite the kernel to save compute (skip elements outside the ragged bounds)
  # - Update various internal structures/metadata to account for the new
  #   block spec.
  # - Set the hacky flag of ragged_originating on the mapping, to signal to
  #   the lowering code to treat mapped dimensions as part of the user grid.
  if lengths_aval:
    batched_grid_mapping = batched_grid_mapping.replace(
        get_grid_indices=lambda indices, maybe_include_mapped_dims: indices,
        local_grid_env=lambda loop_idx, grid: tuple(
            pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid)
        ),
    )

    # Note - on zero filling counterfactuals
    # A debug util to produce a counterfactual version of the when
    # gating, where for all values that don't pass the @when check,
    # we write 0s. This is useful for debugging, as certain lowering paths
    # like mosaic will write the last data as passthrough, leading to
    # potentially confusing results.
    block_mapped_dim_idxs = []
    for block_mapping in batched_grid_mapping.block_mappings:
      mapped_dim_idxs = []
      for i, d in enumerate(block_mapping.block_shape):
        if d is pallas_core.mapped:
          mapped_dim_idxs.append(i)
        else:
          mapped_dim_idxs.append(None)  # type: ignore[arg-type]
      block_mapped_dim_idxs.append(mapped_dim_idxs)

    mapped_dim_idx = None
    for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs):
      if rav is not None:
        stacked_axis = rav[0]
        if mapped_dim_idx is None:
          mapped_dim_idx = mapped_dim_idxs[stacked_axis]
          if mapped_dim_idxs[stacked_axis] is None:
            raise ValueError(
                f"Expected mapped dim to be {stacked_axis}, but got"
                f" {mapped_dim_idxs[stacked_axis]}"
            )
        else:
          assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], (
              f"Different mapped dims - expected {mapped_dim_idx}, but got"
              f" {mapped_dim_idxs[stacked_axis]}"
          )

    # This is the blockspec size of the dimension
    block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings]

    # Parse out the operations from the jaxpr to determine how to mask the
    # output.
    # NOTE! While this *could* be a default dict of None, and None is sound, as
    # it denotes that there is no raggedness for the given var, we explicitly
    # do not do this, so as to get better signal on implementation of rules.
    # A misimplemented rule that does not account for new vars being introduced
    # will result in an error on the next op using the new var. The benefit of
    # forcing implementers to account for all outputs and intermediaries is a
    # very nice one.

    var_to_raggedness = {}
    for invar, rav in zip(jaxpr.invars, ragged_axis_values):
      var_to_raggedness[invar] = rav

    for eqn in jaxpr.eqns:
      prim = eqn.primitive
      if prim not in batching.ragged_prop_rules:
        raise NotImplementedError(f"Not implemented - ragged prop for {prim}")
      rule = batching.ragged_prop_rules[prim]

      invar_raggedness = [
          (
              var_to_raggedness.get(invar, None)
              if isinstance(invar, jax_core.Var)
              else None
          )
          for invar in eqn.invars
      ]
      try:
        invar_raggedness, outvar_raggedness = rule(
            eqn.params, invar_raggedness, eqn.outvars  # type: ignore[arg-type]
        )
      except Exception as e:
        raise RuntimeError(
            f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:"
            f" {eqn.outvars}. Underlying reason: {e}"
        ) from e

      for invar, rav in zip(eqn.invars, invar_raggedness):  # type: ignore[assignment]
        if isinstance(invar, jax_core.Var):
          var_to_raggedness[invar] = rav
      for outvar, rav in zip(eqn.outvars, outvar_raggedness):
        if isinstance(outvar, jax_core.Var):
          var_to_raggedness[outvar] = rav

    for pos, invar in enumerate(jaxpr.invars):
      ragged_axis_values[pos] = var_to_raggedness[invar]

    per_input_ragged_axis_dim: list[int | None] = []
    for rav in ragged_axis_values:
      if rav is not None:
        per_input_ragged_axis_dim.append(rav[1])
      else:
        per_input_ragged_axis_dim.append(None)

    def when_wrapped_kernel(lengths_ref, *args, **kwargs):
      b_idx = primitives.program_id(mapped_dim_idx)

      b_len = lengths_ref[b_idx]
      run_kernel = jnp.array(True)
      for i, _ in enumerate(args):
        ragged_axis_dim = per_input_ragged_axis_dim[i]
        if ragged_axis_dim is None:
          continue
        arg_i_idx = (
            primitives.program_id(ragged_axis_dim)
            * block_shapes[i][ragged_axis_dim]
        )
        run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len)
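      # Illustrative note (hypothetical numbers): with a block of 128 along the
      # ragged dimension and a batch whose length is 300, grid steps 0, 1 and 2
      # satisfy arg_i_idx < b_len and run the inner kernel; later steps fail
      # the @when predicate below and skip their writes.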

      # TODO(mvoz): Unimplemented primitive in pallas
      # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
      # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")

      @pallas_helpers.when(run_kernel)
      def f():
        # Important! This allows us to trace the inner kernel with the correct
        # grid to preserve user program_id semantics. Ex: program_id(0) will
        # always be analogous to program_id(1) in the outer kernel.
        with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
          jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)

    kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
    flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
        list(kernel_avals)
    )

    def _rewrite_index_jaxpr(enumerate_batched_block_mapping):
      arg_pos, batched_block_mapping = enumerate_batched_block_mapping
      indexer_avals = [
          v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars
      ]
      flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten(
          list(indexer_avals)
      )

      def index_rewrite_kernel(*indexer_args):
        ragged_axis_dim = per_input_ragged_axis_dim[arg_pos]

        # The problem here seems to be that we are running this for all
        # inputs, per input, because they each have an indexer - which means
        # that the indexer for the output isn't getting written - before, it
        # always was.

        lengths_ref = indexer_args[-1]
        rest_indexer_args = indexer_args[:-1]
        # Lengths are always the last argument of the indexer.
        # lengths_ref = args[-1]
        # Invariant: Stacked axis is enforced to be the mapped axis above.
        b_idx = indexer_args[mapped_dim_idx]

        nargs = list(rest_indexer_args)

        if ragged_axis_dim is not None:
          val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim]

          # The current index into the ragged dimension.
          # Invariant: There is only one ragged dimension, enforced above.
          i_idx = indexer_args[ragged_axis_dim]

          # grid space -> element space
          i_len = i_idx * val_at_ragged_dim

          # The length of the current batch.
          b_len = lengths_ref[b_idx]

          # Have we reached the end of the current batch?
          not_done = i_len < b_len

          am_last_batch = b_idx == axis_size - 1
          last_good_block = lax.div(b_len, val_at_ragged_dim) - 1

          # The logic below can be thought of as:
          #   if index_oob_ragged:
          #     if not last_batch:
          #       batch_idx += 1
          #       ragged_idx = 0
          #     else:
          #       ragged_idx = last_good_block
          #
          # wherein we find the next good block by incrementing the batch index
          # and setting the ragged index to 0 if we are not in the last batch.
          # Otherwise, we set the ragged index to the last good block.
          b_next = jnp.where(
              not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1)
          )
          i_next = jnp.where(
              not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0)
          )
          nargs[ragged_axis_dim] = i_next
          nargs[mapped_dim_idx] = b_next

        nargs = nargs + [lengths_ref]
        return jax_core.eval_jaxpr(
            batched_block_mapping.index_map_jaxpr.jaxpr,
            batched_block_mapping.index_map_jaxpr.consts,
            *nargs,
        )

      index_jaxpr, _ = _trace_kernel_to_jaxpr(
          index_rewrite_kernel,
          "index_rewrite_kernel",
          batched_grid_mapping,
          tuple(flat_indexer_avals),
          indexer_in_tree,
          tuple(() for _ in flat_indexer_avals),
          indexer=True,
      )

      batched_block_mapping = batched_block_mapping.replace(
          index_map_jaxpr=pe.close_jaxpr(index_jaxpr)
      )
      return batched_block_mapping

    # Important! This allows us to trace the outer kernel with the correct grid
    # to enable accessing the batch program_id.
    with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
      batched_block_mappings = map(
          _rewrite_index_jaxpr, enumerate(batched_block_mappings)
      )

    batched_grid_mapping = batched_grid_mapping.replace(
        block_mappings=tuple(batched_block_mappings),
    )

    kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

    jaxpr, consts = _trace_kernel_to_jaxpr(
        when_wrapped_kernel,
        kernel_src_info,
        batched_grid_mapping,
        tuple(flat_kernel_avals),
        kernel_in_tree,
        tuple(() for _ in flat_kernel_avals),
    )
    if consts:
      raise NotImplementedError("consts not supported in pallas_call")

    # We need to rewrite the input_output_aliases here, the initial call
    # to broadcast is done, and we have inserted a new input (lengths), so
    # there's an off-by-one here now.
    new_input_output_aliases = []
    for k, v in input_output_aliases:
      new_input_output_aliases.append((k + 1, v))
    input_output_aliases = tuple(new_input_output_aliases)
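    # For example, an alias (0, 0) becomes (1, 0): input 0 of the original call
    # is now input 1 of the batched call, because the lengths array is
    # prepended as the first argument below.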

    # assert ragged_axis_length is not None
    args = (ragged_axis_length, *args)
  assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)

  batched_out_avals = []
  for aval in out_avals:
    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
                if config.sharding_in_types.value else None)
    shape = tuple_insert(aval.shape, 0, axis_size)
    batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
  batched_out_avals = tuple(batched_out_avals)  # type: ignore

  out = pallas_call_p.bind(
      *dynamic_grid_args,
      *args,
      jaxpr=jaxpr,
      name_and_src_info=name_and_src_info.replace(
          name=f"{name_and_src_info.name}_batched"
      ),
      grid_mapping=batched_grid_mapping,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
      cost_estimate=batched_cost_estimate,
      out_avals=batched_out_avals,
      backend=backend,
  )
  return out, (0,) * len(out)


batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule


def checkify_pallas_kernel_body_jaxpr(
    body_jaxpr: jax_core.ClosedJaxpr,
    enabled_errors,
    error: checkify.Error,
    grid_mapping: GridMapping) -> tuple[
        jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
  err_vals, err_tree = tree_util.tree_flatten(error)
  err_vals = map(jax_core.get_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]

  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr(
        body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return checked_jaxpr, out_tree, error_effects


def pallas_call_checkify_oob_grid(error: checkify.Error,
                                  enabled_errors,
                                  args: jax_core.Value,
                                  grid_mapping: GridMapping,
                                  input_output_aliases) -> checkify.Error:
  if checkify.OOBError not in enabled_errors:
    return error
  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  output_args = hlo_interpreter._initialize_output_vals(
      grid_mapping.block_mappings_output, args, input_output_aliases)
  scalars, input_args, _ = split_list(
      args, [grid_mapping.num_index_operands,
             grid_mapping.num_inputs],
  )
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
  grid_start_indices = (jnp.int32(0),) * len(grid)
  if grid:
    num_iterations = reduce(jnp.multiply, grid)  # type: ignore[arg-type]
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1

  is_indexing_dim = [
      tuple(b is pallas_core.mapped for b in bm.block_shape)
      for bm in grid_mapping.block_mappings
  ]
  block_shapes = [
      None if iid is None
      else tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
  ]
  # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
  # i:int32 is the iteration index
  # loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
    return i < num_iterations
  def body(carry):
    i, loop_idx = carry
    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
          None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
    # We perform a dynamic slice on the i/o blocks, which will be checked by
    # checkify for OOB accesses.
    map(hlo_interpreter._dynamic_slice, start_indices, block_shapes,
        [*input_args, *output_args], is_indexing_dim)
    return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx))
  def f(_):
    lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices)
    )
  flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
  wrapped_loop, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(f,
                   debug_info=api_util.debug_info("checkify oob_grid_access",
                                                  f, (0,), {})),
      jaxpr_in_tree)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    avals_in = map(jax_core.get_aval, flat_args)
    traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
        wrapped_loop, list(avals_in))
    traced_loop = jax_core.ClosedJaxpr(traced_loop, consts)
    out_error, _ = checkify.checkify_jaxpr(
        traced_loop, checkify.index_checks, error, flat_args)
  return out_error


def pallas_call_checkify_rule(error: checkify.Error,
                              enabled_errors,
                              *args: jax_core.Value,
                              jaxpr: jax_core.Jaxpr,
                              interpret: bool,
                              input_output_aliases: tuple[tuple[int, int], ...],
                              grid_mapping: GridMapping,
                              out_avals: tuple[jax_core.AbstractValue, ...],
                              **kwargs):
  # Check for OOB accesses in the grid.
  error = pallas_call_checkify_oob_grid(error, enabled_errors,
                                        args, grid_mapping,
                                        input_output_aliases)
  # We implement the checkify rule in 4 steps:
  # 1) First, trace the kernel body to get the expected error shapes.
  # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
  #    and outputs.
  # 3) Create a new kernel which stores the errors in output memrefs instead of
  #    returning them, since pallas kernels do not return outputs.
  # 4) Create block specs for the error state and call pallas_call with
  #    the new kernel.
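  # The resulting checked kernel (see checked_kernel_fn below) takes its
  # arguments in the order
  #   (*scalars, *in_error_refs, *inputs, *out_error_refs, *outputs, *scratch),
  # mirroring the split_list call inside it.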
  dynamic_grid_bounds, scalars, args = split_list(  # type: ignore
      args, [grid_mapping.num_dynamic_grid_bounds,
             grid_mapping.num_index_operands]
  )
  num_scalars = len(scalars)
  num_kernel_inputs = len(args)
  num_kernel_outputs = grid_mapping.num_outputs

  # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of
  # the required inputs.
  closed_jaxpr = pe.close_jaxpr(jaxpr)
  _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
      closed_jaxpr, enabled_errors, error, grid_mapping)
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_in_tree = jax.tree.flatten(error)
  shaped_err_avals = map(jax_core.get_aval, err_vals)

  # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
  # all enabled errors removed, but have the error as inputs and return values.
  input_avals = [v.aval for v in jaxpr.invars]
  num_err_vals = len(err_vals)
  shaped_input_avals = tuple(input_avals)
  checkify_in_avals = [*shaped_err_avals,
                       *shaped_input_avals]
  closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    checked_jaxpr, error_out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
        closed_kernel_jaxpr, enabled_errors, err_in_tree, *checkify_in_avals)
  # Create a new kernel that, instead of returning the errors as values,
  # writes them to memrefs. This is because pallas kernels are expected
  # to have no return values and to write their outputs to refs instead.
  def checked_kernel_fn(*args):
    (scalars, in_error_refs, inputs, out_error_refs, outputs, scratch
     ) = split_list(
         args,
         [num_scalars, num_err_vals,
          num_kernel_inputs, num_err_vals, num_kernel_outputs])
    # TODO(b/350593266): Remove zero-indexing once we support ()-shaped scalars.
    input_error_vals = [err_ref[0, 0] for err_ref in in_error_refs]
    # We need to re-order the inputs here. A checkified jaxpr always expects
    # errors before other arguments.
    jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
    assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
    result_flat = jax_core.eval_jaxpr(
        checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
    output_errors, _ = split_list(result_flat, [num_err_vals])
    # Store new errors back in the error refs.
    for in_ref, out_ref, error in zip(
        in_error_refs, out_error_refs, output_errors):
      in_ref[0, 0] = error
      out_ref[0, 0] = error
    return []

  # Trace the new checked_kernel_fn with memref inputs so that
  # we can replace the old kernel jaxpr with the new checked jaxpr in
  # pallas_call.

  # _ensure_2d_error_shape is only necessary because pallas does not support
  # ()-shaped memrefs.
  # TODO(b/350593266): Remove once we support ()-shaped scalars.
  def _ensure_2d_error_shape(arg):
    if isinstance(arg, jax_core.ShapedArray):
      dtype = arg.dtype
      return jax_core.ShapedArray((1, 1) + arg.shape, dtype=dtype,
                                  weak_type=arg.weak_type)
    elif isinstance(arg, jax.Array):
      return jnp.reshape(arg, (1, 1) + arg.shape)
    else:
      return jnp.array([[arg]])
  shaped_err_avals = map(_ensure_2d_error_shape, shaped_err_avals)
  err_vals = map(_ensure_2d_error_shape, err_vals)
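  # For example (illustrative): a ()-shaped error aval becomes a (1, 1)-shaped
  # one, and a scalar error payload `5` becomes `jnp.array([[5]])`, so every
  # error value lives in a (1, 1)-shaped ref inside the kernel.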
  error_memref_aval = [pallas_core.AbstractMemoryRef(
      err_val, pallas_core.MemorySpace.ERROR) for err_val in shaped_err_avals]
  shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
      shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
  retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
                      *error_memref_aval, *output_aval, *scratch_aval]
  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
  debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
                              retrace_in_avals, {})
  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)

  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrapped_kernel_with_err, jaxpr_flat_avals)

  # Prepare pallas_call inputs. We need to create new block specs
  # for the new error inputs and outputs.
  error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
  error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
  error_origins = tuple(f"errors[{tree_util.keystr(p)}]" for p in error_paths)
  error_block_mappings = map(
      partial(
          pallas_core._convert_block_spec_to_block_mapping,
          index_map_avals=grid_mapping.index_map_avals,
          index_map_tree=grid_mapping.index_map_tree,
          grid=grid_mapping.grid,
          mapped_dims=grid_mapping.vmapped_dims),
      error_block_specs, error_origins, shaped_err_avals)
  input_block_mappings, output_block_mappings = split_list(
      grid_mapping.block_mappings, [num_kernel_inputs,])
  grid_mapping_with_error = grid_mapping.replace(
      block_mappings=(*error_block_mappings, *input_block_mappings,
                      *error_block_mappings, *output_block_mappings),
      num_inputs=grid_mapping.num_inputs + len(error_block_mappings),
      num_outputs=grid_mapping.num_outputs + len(error_block_mappings)
  )
  # Bump all input_output_aliases by num_err_vals to make room for the errors.
  # TODO(justinfu): Don't bump scalars here.
  input_output_aliases = tuple(
      (i + num_err_vals, o + num_err_vals) for (i, o) in input_output_aliases)
  input_output_aliases_with_error = tuple(
      (i + num_scalars, i) for i in range(num_err_vals)) + input_output_aliases

  new_vals_in = [*scalars, *err_vals, *args]
  new_out_avals = (*shaped_err_avals, *out_avals)
  result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
                              jaxpr=final_jaxpr,
                              interpret=interpret,
                              grid_mapping=grid_mapping_with_error,
                              input_output_aliases=input_output_aliases_with_error,
                              out_avals=new_out_avals,
                              **kwargs)
  errors, results = split_list(result, [num_err_vals])
  # TODO(b/350593266): Remove line below once we support ()-shaped scalars.
  errors = [err_val[0, 0] for err_val in errors]
  new_error, _ = jax.tree.unflatten(error_out_tree, errors)
  return new_error, results

checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

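# From a user's point of view, this registration is what lets `checkify` thread
# error values through a `pallas_call`. A minimal sketch (assuming a function
# `f` that internally calls `pl.pallas_call`; not exercised here):
#
#   checked = checkify.checkify(jax.jit(f), errors=checkify.index_checks)
#   err, out = checked(x)
#   err.throw()  # raises if the kernel performed an out-of-bounds access

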
@weakref_lru_cache
def _trace_kernel_to_jaxpr(
    fun: Callable,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
    kernel_in_tree: tree_util.PyTreeDef,
    kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
    indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
  debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
  wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug), kernel_in_tree)
  wrapped_kernel_fun = primitives.wrap_with_transforms(
      wrapped_kernel_fun, kernel_in_transforms
  )
  with grid_mapping.trace_env():
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                     kernel_avals)
  if consts:
    consts_avals = [jax_core.get_aval(c) for c in consts]
    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
      raise ValueError(
          f"The kernel function in the pallas_call {name_and_src_info} "
          f"captures constants {consts_avals}. "
          "You should pass them as inputs")

  kernel_out_tree = out_tree_thunk()
  if not indexer and kernel_out_tree != tree_util.tree_structure(None):
    raise ValueError(
        f"The kernel function in the pallas_call {name_and_src_info} "
        f"should return None. It returns a PyTree: {kernel_out_tree}")
  return jaxpr, tuple(consts)

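# An example of the capture error guarded against above (a sketch, not library
# code): a kernel that closes over a traced or concrete array, e.g.
#
#   scale = 2.0 * some_array
#   def kernel(x_ref, o_ref):
#     o_ref[...] = x_ref[...] * scale   # `scale` is captured, not passed in
#
# fails tracing with the "captures constants" error unless `scale` is passed
# to `pallas_call` as a regular input (or is a Ref).

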
_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
    "jax_pallas_use_mosaic_gpu",
    default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
    help=(
        "If True, lower Pallas kernels to the experimental Mosaic GPU"
        " dialect, instead of Triton IR."
    ),
)

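# For example (a sketch of how the flag is typically toggled): the Mosaic GPU
# path can be selected globally with the environment variable
# `JAX_PALLAS_USE_MOSAIC_GPU=1` or via
# `jax.config.update("jax_pallas_use_mosaic_gpu", True)`; per-call selection
# uses the `backend=` argument of `pallas_call` defined below.
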

_PALLAS_VERBOSE_ERRORS = config.bool_flag(
    "jax_pallas_verbose_errors",
    default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True),
    help=(
        "If True, print verbose error messages for Pallas kernels."
    ),
)


def _verbose_errors_enabled() -> bool:
  return _PALLAS_VERBOSE_ERRORS.value


def _unsupported_lowering_error(platform: str) -> Exception:
  return ValueError(
      f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
      " install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install"
      " jaxlib TPU and libtpu. See"
      " https://jax.readthedocs.io/en/latest/installation.html."
  )


_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]


def _pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    interpret: bool,
    backend: _Backend | None,
    **params,
):
  if params['jaxpr'].constvars:
    raise ValueError('Cannot lower a pallas_call with constants.')
  if interpret:
    if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams):
      impl = partial(mosaic_tpu_interpret.interpret_pallas_call, **params)
    else:
      impl = partial(hlo_interpreter.pallas_call_hlo_interpret,
                     backend=backend,
                     **params)
    return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

  def cpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    raise ValueError("Only interpret mode is supported on CPU backend.")

  def tpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    if backend and backend != "mosaic_tpu":
      raise ValueError("Only mosaic backend supported for TPU")
    if mosaic_tpu_backend is None:
      raise _unsupported_lowering_error("tpu")
    return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
        ctx, *in_nodes, **params
    )

  def gpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    try:
      match backend:
        case "mosaic_gpu":
          from jax._src.pallas.mosaic_gpu import pallas_call_registration
        case "triton":
          from jax._src.pallas.triton import pallas_call_registration  # type: ignore
        case None:
          if _PALLAS_USE_MOSAIC_GPU.value:
            from jax._src.pallas.mosaic_gpu import pallas_call_registration
          else:
            from jax._src.pallas.triton import pallas_call_registration  # type: ignore
        case _:
          raise ValueError(f"Unsupported backend: {backend}")
    except ImportError as e:
      raise _unsupported_lowering_error("gpu") from e

    return pallas_call_registration.pallas_call_lowering(
        ctx, *in_nodes, **params
    )

  return mlir.lower_per_platform(ctx, "pallas_call",
                                 dict(cpu=cpu_lowering,
                                      tpu=tpu_lowering,
                                      cuda=gpu_lowering,
                                      rocm=gpu_lowering),
                                 None,  # default_rule
                                 effects.no_effects,
                                 *in_nodes,
                                 interpret=interpret,
                                 **params)


mlir.register_lowering(pallas_call_p, _pallas_call_lowering)

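# Dispatch example (a sketch, not exercised here): on a CUDA device,
# `pl.pallas_call(kernel, out_shape=..., backend="mosaic_gpu")(x)` is routed
# through `gpu_lowering` above to the Mosaic GPU registration, while
# `backend=None` falls back to Triton unless `jax_pallas_use_mosaic_gpu` is set.

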
def _pallas_custom_str_eqn_compact(
    prim: jax_core.Primitive, params: dict[Any, Any]
) -> str:
  del prim, params
  # Hide most info from the compact str representation.
  return "pallas_call"
jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
    _pallas_custom_str_eqn_compact
)


def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):
  with grid_mapping.trace_env():
    return pallas_call_p.abstract_eval(
        *in_avals, grid_mapping=grid_mapping, **params
    )
jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule


def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
  match out_shape:
    case jax.ShapeDtypeStruct():
      return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
    case pallas_core.MemoryRef():
      return out_shape.get_array_aval()
    case _:
      if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")):
        raise ValueError(f"Invalid out_shape type: {type(out_shape)}")
      return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
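# For example (illustrative): `jax.ShapeDtypeStruct((8, 128), jnp.float32)`
# maps to `ShapedArray((8, 128), float32)`, while a `pallas_core.MemoryRef`
# out_shape is converted via its `get_array_aval()` method.

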
@state_discharge.register_discharge_rule(pallas_call_p)
def _pallas_call_state_discharge_rule(
    avals_in,
    avals_out,
    *args,
    jaxpr: jax_core.Jaxpr,
    input_output_aliases: tuple[tuple[int, int], ...],
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    backend: _Backend | None = None
):
  del avals_out
  assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
  num_refs = len(jaxpr.constvars)
  ref_avals, rest_in_avals = split_list(avals_in, [num_refs])
  assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals)
  ref_avals = [
      pallas_core.AbstractMemoryRef(
          ref_aval.inner_aval, pallas_core.MemorySpace.ANY
      )
      for ref_aval in ref_avals
  ]
  ref_block_specs = [
      pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
  ] * num_refs
  ref_block_mappings = [
      block_spec.to_block_mapping(
          origin="",  # TODO(sharadmv): enable origins for refs
          array_aval=ref_aval.inner_aval,
          index_map_avals=grid_mapping.index_map_avals,
          index_map_tree=grid_mapping.index_map_tree,
          grid=grid_mapping.grid,
          mapped_dims=grid_mapping.vmapped_dims,
      ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
  ]
  in_block_mappings, out_block_mappings = split_list(
      grid_mapping.block_mappings, [grid_mapping.num_inputs]
  )
  new_block_mappings = (
      *ref_block_mappings,
      *in_block_mappings,
      *ref_block_mappings,
      *out_block_mappings,
  )
  new_grid_mapping = grid_mapping.replace(
      block_mappings=new_block_mappings,
      num_inputs=grid_mapping.num_inputs + num_refs,
      num_outputs=grid_mapping.num_outputs + num_refs)
  new_input_output_aliases = [
      (i + grid_mapping.num_index_operands, i) for i in range(num_refs)
  ]
  for i, o in input_output_aliases:
    new_input_output_aliases.append((i + num_refs, o + num_refs))
  ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals]
  new_out_avals = (*ref_out_avals, *out_avals)
  ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
      args,
      [
          num_refs,
          grid_mapping.num_dynamic_grid_bounds,
          grid_mapping.num_index_operands,
      ],
  )
  def _rewritten_body(*args):
    index_args, in_args, out_args, rest_args = split_list(
        args, [new_grid_mapping.num_index_operands, new_grid_mapping.num_inputs,
               new_grid_mapping.num_outputs])
    ref_in_args, in_args = split_list(in_args, [num_refs])
    ref_out_args, out_args = split_list(out_args, [num_refs])
    # We don't care about ref_out_args because they are aliased to ref_in_args.
    del ref_out_args
    jax_core.eval_jaxpr(
        jaxpr, ref_in_args, *index_args, *in_args, *out_args, *rest_args
    )
    return []
  index_map_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = (
      split_list(
          [v.aval for v in jaxpr.invars],
          [
              grid_mapping.num_index_operands,
              grid_mapping.num_inputs,
              grid_mapping.num_outputs,
          ],
      )
  )
  new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info),
      [
          *index_map_avals,
          *ref_avals,
          *jaxpr_in_avals,
          *ref_avals,
          *jaxpr_out_avals,
          *jaxpr_rest_avals,
      ],
  )
  out_flat = pallas_call_p.bind(
      *consts,
      *dynamic_grid_bounds,
      *index_operands,
      *ref_args,
      *rest_args,
      jaxpr=new_jaxpr,
      input_output_aliases=new_input_output_aliases,
      grid_mapping=new_grid_mapping,
      name_and_src_info=name_and_src_info,
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
      cost_estimate=cost_estimate,
      out_avals=new_out_avals,
      backend=backend,
  )
  refs_out, rest = split_list(out_flat, [num_refs])
  updated_vals_in = refs_out + [None] * len(rest_in_avals)
  return updated_vals_in, rest


def pallas_call(
    kernel: Callable[..., None],
    out_shape: Any,
    *,
    grid_spec: GridSpec | None = None,
    grid: TupleGrid = (),
    in_specs: BlockSpecTree = no_block_spec,
    out_specs: BlockSpecTree = no_block_spec,
    scratch_shapes: ScratchShapeTree = (),
    input_output_aliases: dict[int, int] = {},
    debug: bool = False,
    interpret: bool = False,
    name: str | None = None,
    compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
    cost_estimate: CostEstimate | None = None,
    backend: _Backend | None = None,
) -> Callable[..., Any]:
  """Invokes a Pallas kernel on some inputs.

  See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.

  Args:
    kernel: the kernel function, which receives a Ref for each input and
      output. The shapes of the Refs are given by the ``block_shape`` in the
      corresponding ``in_specs`` and ``out_specs``.
    out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shapes
      and dtypes of the outputs.
    grid_spec: an alternative way to specify ``grid``, ``in_specs``,
      ``out_specs`` and ``scratch_shapes``. If given, those other parameters
      must not also be given.
    grid: the iteration space, as a tuple of integers. The kernel is executed
      as many times as ``prod(grid)``.
      See details at :ref:`pallas_grid`.
    in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
      a structure matching that of the positional arguments.
      The default value for ``in_specs`` specifies the whole array for all
      inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
      See details at :ref:`pallas_blockspec`.
    out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
      a structure matching that of the outputs.
      The default value for ``out_specs`` specifies the whole array,
      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
      See details at :ref:`pallas_blockspec`.
    scratch_shapes: a PyTree of backend-specific temporary objects required
      by the kernel, such as temporary buffers, synchronization primitives,
      etc.
    input_output_aliases: a dictionary mapping the index of some inputs to
      the index of the output that aliases them. These indices are in the
      flattened inputs and outputs.
    debug: if True, Pallas prints various intermediate forms of the kernel
      as it is being processed.
    interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
      grid whose body is the kernel lowered as a JAX function. This does not
      require a TPU or a GPU, and is the only way to run Pallas kernels on
      CPU. This is useful for debugging.
    name: if present, specifies the name to use for this kernel call in
      debugging and error messages. To this name we append the file and line
      where the kernel function is defined, e.g.:
      ``{name} for kernel function {kernel_name} at {file}:{line}``.
      If missing, then we use ``{kernel_name} at {file}:{line}``.
    compiler_params: optional compiler parameters. If a dict is provided, it
      should be of the form ``{platform: {param_name: param_value}}``, where
      platform is either 'mosaic' or 'triton'. It is also possible
      to pass in ``jax.experimental.pallas.tpu.TPUCompilerParams`` for TPUs and
      ``jax.experimental.pallas.gpu.TritonCompilerParams`` for Triton/GPUs.
    backend: optional string literal, one of "mosaic_tpu", "triton" or
      "mosaic_gpu", determining the backend to be used. None means let Pallas
      decide.

  Returns:
    A function that can be called on a number of positional array arguments to
    invoke the Pallas kernel.
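
  Example: a minimal sketch of an element-wise kernel (run in interpret mode,
  so it does not require a TPU or GPU)::

      import jax
      import jax.numpy as jnp
      from jax.experimental import pallas as pl

      def add_kernel(x_ref, y_ref, o_ref):
        # Each Ref holds one block; read, compute, and write back.
        o_ref[...] = x_ref[...] + y_ref[...]

      x = jnp.arange(8, dtype=jnp.float32)
      y = jnp.ones(8, dtype=jnp.float32)
      out = pl.pallas_call(
          add_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          interpret=True,
      )(x, y)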
  """
  kernel_src_info = api_util.fun_sourceinfo(kernel)
  name_and_src_info = pallas_core.NameAndSrcInfo.from_pallas_call(
      name, kernel_src_info)
  if compiler_params is None:
    compiler_params = {}
  if isinstance(compiler_params, pallas_core.CompilerParams):
    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
      raise ValueError(
          f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
    compiler_params = {
        compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
    }
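  # For example (illustrative, from the caller's side): passing
  # `pltpu.TPUCompilerParams(dimension_semantics=("parallel",))` is converted
  # by the normalization above into a dict keyed by its PLATFORM ("mosaic"),
  # while a plain dict such as `{"mosaic": {"dimension_semantics": ("parallel",)}}`
  # is passed through unchanged.
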
  if grid_spec is None:
    grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
  else:
    if grid:
      raise ValueError(
          "If `grid_spec` is specified, then `grid` must "
          f"be `()`. It is {grid}")
    if in_specs is not no_block_spec:
      raise ValueError(
          "If `grid_spec` is specified, then `in_specs` must "
          f"be `no_block_spec`. It is {in_specs}")
    if out_specs is not no_block_spec:
      raise ValueError(
          "If `grid_spec` is specified, then `out_specs` must "
          f"be `no_block_spec`. It is {out_specs}")
    if scratch_shapes:
      raise ValueError(
          "If `grid_spec` is specified, then `scratch_shapes` must "
          f"be `()`. It is {scratch_shapes}")
  del grid, in_specs, out_specs
  grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
  # TODO(necula): this canonicalization may be convenient for some usage
  # but it is lossy, because it prevents expressing functions that return
  # lists.
  if isinstance(out_shape, list):
    out_shape = tuple(out_shape)
  flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
  out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)

  @partial(jax.jit, inline=True)
  def wrapped(*args):
    flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
    in_paths, flat_args = unzip2(flat_args_with_paths)
    flat_in_avals = tuple(jax_core.get_aval(a) for a in flat_args)

    flat_out_avals = tuple(_convert_out_shape_to_aval(v)
                           for v in flat_out_shapes)

    kernel_fun_sig = api_util.fun_signature(kernel)
    arg_names = None
    if kernel_fun_sig:
      kernel_debug_info = api_util.debug_info(
          "pallas_call kernel",
          kernel,
          [1] * len(kernel_fun_sig.parameters), {})
      arg_names = kernel_debug_info.arg_names
      del kernel_debug_info
    in_origins = tuple(in_path_to_input_origin(p, arg_names)
                       for p in in_paths)
    out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
    # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
    kernel_args, grid_mapping = pallas_core.get_grid_mapping(
        grid_spec,
        flat_in_avals, in_tree, in_origins,
        flat_out_avals, out_tree, out_origins)
    flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args)
    flat_kernel_avals = tuple(
        x.ref if isinstance(x, state_types.TransformedRef) else x
        for x in flat_kernel_args
    )
    # Note that only a subset of all transforms can be found here, and they are
    # never expected to contain any arrays.
    kernel_arg_transforms = tuple(
        x.transforms if isinstance(x, state_types.TransformedRef) else ()
        for x in flat_kernel_args
    )
    jaxpr, consts = _trace_kernel_to_jaxpr(
        kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
        kernel_in_tree, kernel_arg_transforms)
    for i_idx, o_idx in input_output_aliases.items():
      if i_idx not in range(len(flat_in_avals)):
        raise ValueError(
            f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
            f"input index {i_idx} outside the range "
            f"[0, {len(flat_in_avals)})")
      if o_idx not in range(len(flat_out_avals)):
        raise ValueError(
            f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
            f"output index {o_idx} outside the range "
            f"[0, {len(flat_out_avals)})")
      in_aval = flat_in_avals[i_idx]
      out_aval = flat_out_avals[o_idx]
      if isinstance(in_aval, jax_core.DShapedArray):
        new_shape = []
        for d in in_aval.shape:
          if isinstance(d, int):
            new_shape.append(d)
          else:
            new_shape.append(d.dtype.bound)

        in_aval = jax_core.ShapedArray(
            tuple(new_shape), in_aval.dtype, in_aval.weak_type
        )

      if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
        raise ValueError(
            f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
            f"referring to input{tree_util.keystr(in_paths[i_idx])} with "
            f"abstract value {in_aval} "
            f"and to output{tree_util.keystr(out_paths[o_idx])} with "
            f"a different abstract value {out_aval}.")

    index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
    out_flat = pallas_call_p.bind(
        *consts,
        *dynamic_grid_bounds,
        *index_args,
        *rest_args,
        out_avals=flat_out_avals,
        jaxpr=jaxpr,
        name_and_src_info=name_and_src_info,
        debug=debug,
        interpret=interpret,
        grid_mapping=grid_mapping,
        input_output_aliases=tuple(input_output_aliases.items()),
        compiler_params=compiler_params,
        cost_estimate=cost_estimate,
        backend=backend,
    )
    out = tree_util.tree_unflatten(out_tree, out_flat)
    return out
  return wrapped


def in_path_to_input_origin(
    in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None
) -> pallas_core.OriginStr:
  """Converts `args[k]<rest>` into `arg_k_name<rest>`."""
  if arg_names is None:
    return f"args{tree_util.keystr(in_path)}"
  if len(in_path) == 0:
    return "args"
  arg_idx, *rest_path = in_path
  if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(
      arg_names
  ):
    if arg_names[arg_idx.idx] is None:
      # TODO(necula): when is this needed?
      # Repro: pallas_test:test_with_input_output_aliasing
      return f"args{tree_util.keystr(in_path)}"
    return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
  else:
    return f"args{tree_util.keystr(tuple(in_path))}"

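# For example (illustrative): with `arg_names=("x", "y")`, the flattened input
# path `(SequenceKey(1), SequenceKey(0))` -- i.e. `args[1][0]` -- maps to the
# origin string `"y[0]"`; with `arg_names=None` it stays `"args[1][0]"`.

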
# We import the TPU backend at the top level because it defines flags. Note that
# we can only do that at the bottom of this file, because it also depends on
# this module already being initialized.

try:
  from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend
except ImportError:
  mosaic_tpu_backend = None  # type: ignore

try:
  from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret
except ImportError:
  mosaic_tpu_interpret = types.SimpleNamespace(  # type: ignore
      TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)),
  )