# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for calling pallas functions from JAX."""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
from functools import partial, reduce
|
|
import itertools
|
|
from typing import Any
|
|
|
|
import jax
|
|
from jax import api_util
|
|
from jax import lax
|
|
from jax import tree_util
|
|
from jax._src import ad_util
|
|
from jax._src import checkify
|
|
from jax._src import config
|
|
from jax._src import core as jax_core
|
|
from jax._src import effects
|
|
from jax._src import linear_util as lu
|
|
from jax._src import state
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.interpreters import xla
|
|
from jax._src.pallas import core as pallas_core
|
|
from jax._src.pallas.primitives import uninitialized_value
|
|
from jax._src.state import discharge as state_discharge
|
|
from jax._src.state import primitives as sp
|
|
from jax._src.util import (
|
|
safe_map,
|
|
safe_zip,
|
|
split_list,
|
|
tuple_insert,
|
|
weakref_lru_cache,
|
|
)
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
map, unsafe_map = safe_map, map
|
|
zip, unsafe_zip = safe_zip, zip
|
|
|
|
Grid = pallas_core.Grid
|
|
GridSpec = pallas_core.GridSpec
|
|
BlockMapping = pallas_core.BlockMapping
|
|
GridMapping = pallas_core.GridMapping
|
|
BlockSpec = pallas_core.BlockSpec
|
|
BlockSpecTree = pallas_core.BlockSpecTree
|
|
NoBlockSpec = pallas_core.NoBlockSpec
|
|
no_block_spec = pallas_core.no_block_spec
|
|
|
|
pallas_call_p = jax_core.Primitive('pallas_call')
|
|
pallas_call_p.multiple_results = True
|
|
|
|
def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
  if start_idx is None:
    assert is_indexing is None
    return value
  assert is_indexing is not None
  start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx)
  output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
  squeeze_dims = tuple(
      np.arange(len(is_indexing))[np.array(is_indexing, dtype=np.bool_)])
  return lax.squeeze(output, squeeze_dims)

def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
                                is_indexing):
  if start_idx is None:
    assert is_indexing is None
    return update
  assert is_indexing is not None
  start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx)
  broadcast_dims = tuple(i for i, b in enumerate(is_indexing) if not b)
  update = lax.broadcast_in_dim(update, block_shape, broadcast_dims)
  assert update.shape == block_shape
  return lax.dynamic_update_slice(value, update, start_idx)

def _pad_values_to_block_dimension(value, block_shape):
  """Pads values so the shape evenly divides into block dimensions.

  For example, if values has a shape of (33, 2, 5) with a block_shape of
  (32, 2, 4), this function will pad the value to shape (64, 2, 8).

  Args:
    value: Array to be padded.
    block_shape: Block shapes to use for padding. If None, no padding will
      be performed.

  Returns:
    A padded array.
  """
  if block_shape is None:
    return value
  padded_shape = tuple(
      ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
  )
  if padded_shape != value.shape:
    pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
    pad_value = uninitialized_value(shape=(), dtype=value.dtype)
    value = jnp.pad(value, pad_width, constant_values=pad_value)
  return value

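# A comment-only illustration of the rounding rule above, using the shapes
# from the docstring (not executed at import time):
#
#   value = jnp.zeros((33, 2, 5))
#   _pad_values_to_block_dimension(value, (32, 2, 4)).shape  # (64, 2, 8)
#
# Each dimension is rounded up to the next multiple of its block dimension,
# and the new elements are filled with an uninitialized (poison) value.
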
def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
  return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)

def _initialize_output_vals(
    out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
  oi_map = {v: k for k, v in input_output_aliases}
  output_vals = []
  for i, out_shape in enumerate(out_shapes):
    if i in oi_map:
      output_vals.append(input_args[oi_map[i]])
    else:
      output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
  return output_vals

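# Note that `input_output_aliases` pairs are (input_index, output_index), so
# with input_output_aliases == ((0, 1),) output 1 starts out as input 0 and
# every other output starts out uninitialized (a comment-only note).
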
def _logical_to_interpret_mode_dtype(dtype):
  """Converts logical dtypes into JAX dtypes for interpret mode.

  This function is used to convert device-specific dtypes that have no
  corresponding equivalent in JAX/XLA into a type that can be executed
  by the XLA interpreter (e.g. TPU semaphores -> int32).
  """
  if (hasattr(dtype, "_rules") and
      hasattr(dtype._rules, "pallas_interpret_element_aval")):
    return dtype._rules.pallas_interpret_element_aval(dtype).dtype
  return dtype

def _logical_aval_to_interpret_mode_aval(aval):
  """Logical to interpret mode aval conversion."""
  if isinstance(aval, pallas_core.AbstractMemoryRef):
    inner_aval = _logical_aval_to_interpret_mode_aval(aval.inner_aval)
    return aval.update(inner_aval=inner_aval)
  if isinstance(aval, jax_core.ShapedArray):
    inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype)
    return jax_core.ShapedArray(aval.shape, inner_dtype,
                                weak_type=aval.weak_type,
                                named_shape=aval.named_shape)
  return aval

def _get_next_indices(grid, indices):
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = dim_size == i
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))

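# `_get_next_indices` steps through the grid like an odometer: the trailing
# dimension increments fastest, and a carry resets it and bumps the dimension
# to its left. For a grid of (2, 3) the traced sequence is (a comment-only
# illustration):
#
#   (0, 0) -> (0, 1) -> (0, 2) -> (1, 0) -> (1, 1) -> (1, 2) -> (0, 0)
#
# The surrounding while_loop stops after `num_iterations` steps, so the
# final wrap-around back to (0, 0) is never observed.
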
def _pallas_call_impl(*args, jaxpr, name, out_shapes,
                      interpret, debug: bool,
                      in_shapes,
                      input_output_aliases: tuple[tuple[int, int], ...],
                      grid_mapping: GridMapping,
                      compiler_params: Any):
  dynamic_grid_args, args = split_list(  # type: ignore
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  if interpret:
    # If we're in interpreter mode, we *scan* over the grid and eval the
    # discharged jaxpr. This should reproduce exactly what compiling to Triton
    # will do.
    dynamic_grid_args_iter = iter(dynamic_grid_args)
    grid = tuple(
        a if a is not pallas_core.dynamic_grid_dim
        else next(dynamic_grid_args_iter)
        for a in grid_mapping.grid
    )
    assert next(dynamic_grid_args_iter, None) is None
    with grid_mapping.trace_env():
      discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
    if debug:
      print(discharged_jaxpr)
    out = _initialize_output_vals(out_shapes, args, input_output_aliases)
    scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
    # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
    num_invars = len(jaxpr.invars)
    num_inputs_outputs = (
        num_invars
        - grid_mapping.num_index_operands
        - grid_mapping.num_scratch_operands
    )
    _, _, scratch_invars = split_list(
        jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
    )
    scratch_avals = [v.aval for v in scratch_invars]
    scratch_values = _initialize_scratch_vals(scratch_avals)

    carry = []
    for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
      if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
        padding = bm.indexing_mode.padding
        if padding is not None and any(p != (0, 0) for p in padding):
          if input_output_aliases:
            raise NotImplementedError("Padding with aliasing not supported.")
          x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
      carry.append(x)

    block_shapes_without_mapped_dims = [
        None if block_mapping is None else block_mapping.block_shape
        for block_mapping in grid_mapping.block_mappings
    ]
    is_indexing_dim = [
        None if bm is None else tuple(b is pallas_core.mapped for b in bm)
        for bm in block_shapes_without_mapped_dims
    ]
    block_shapes = [
        None if (bm is None or iid is None)
        else tuple(1 if i else b for i, b in zip(iid, bm))
        for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
    ]

    # Pad values to evenly divide into block dimensions.
    # This allows interpret mode to catch errors on OOB memory accesses
    # by poisoning values with NaN. It also fixes an inconsistency with
    # lax.dynamic_slice where if the slice goes out of bounds, it will instead
    # move the start_index backwards so the slice will fit in memory.
    carry = map(_pad_values_to_block_dimension, carry, block_shapes)
    carry.extend(scratch_values)

    num_inout = len(args) + len(out)
    grid_start_indices = (jnp.int32(0),) * len(grid)
    if grid:
      num_iterations = reduce(jnp.multiply, grid)
    else:
      # Base case is always one iteration when grid is ()
      num_iterations = 1

    def cond(carry):
      i, *_ = carry
      return i < num_iterations

    def body(carry):
      i, loop_idx, *carry = carry
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.mapped_dims
      )
      carry, scratch = split_list(carry, [num_inout])
      with pallas_core.grid_env(local_grid_env):
        start_indices = [
            None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
            for bm in grid_mapping.block_mappings]
      blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                   is_indexing_dim)
      with pallas_core.grid_env(local_grid_env):
        assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
            scratch_values
        ), (
            len(discharged_jaxpr.invars),
            len(scalars),
            len(blocks),
            len(scratch_values),
        )
        blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
                                     *blocks, *scratch)
      blocks = blocks[grid_mapping.num_index_operands:]
      blocks, out_scratch = split_list(blocks, [num_inout])
      carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                  carry, blocks, is_indexing_dim)
      return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

    (_, _, *carry) = lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices, *carry)
    )
    _, out, _ = split_list(carry, [len(args), len(out)])
    assert len(grid_mapping.block_mappings) == len(args) + len(out)
    out_block_mappings = grid_mapping.block_mappings[len(args):]
    out_nopad = []
    for o, bm in zip(out, out_block_mappings):
      if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
        padding = bm.indexing_mode.padding
        if padding is not None and any(p != (0, 0) for p in padding):
          if input_output_aliases:
            raise NotImplementedError("Padding with aliasing not supported.")
          pad_low, pad_high = zip(*padding)
          limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
          o = lax.slice(o, pad_low, limit_indices)
      out_nopad.append(o)
    return out_nopad
  return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
                             in_shapes=in_shapes,
                             out_shapes=out_shapes,
                             grid_mapping=grid_mapping, interpret=interpret,
                             debug=debug,
                             input_output_aliases=input_output_aliases,
                             compiler_params=compiler_params)
pallas_call_p.def_impl(_pallas_call_impl)

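# Interpret mode can be exercised end to end on any backend. A hedged,
# comment-only usage sketch (`add_one` is hypothetical, not part of this
# module):
#
#   def add_one(x_ref, o_ref):
#     o_ref[...] = x_ref[...] + 1.0
#
#   out = pallas_call(
#       add_one,
#       out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
#       interpret=True,
#   )(jnp.zeros((8, 128), jnp.float32))
#
# This takes the `interpret` branch of `_pallas_call_impl`, which discharges
# the stateful jaxpr and evaluates it once per grid point.
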
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
  return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)

def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
                          input_output_aliases: tuple[tuple[int, int], ...],
                          in_shapes, out_shapes, grid_mapping, debug,
                          interpret, compiler_params: Any):
  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
  if grid_mapping.num_index_operands:
    raise NotImplementedError
  if input_output_aliases:
    raise NotImplementedError("JVP with aliasing not supported.")
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
  # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
  # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
  # `Ref`s that are read from followed by output `Ref`s that are written to.
  # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new
  # jaxpr that has tangents following primals. In order for this jaxpr to be
  # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle
  # around the jaxpr's invars.
  primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
      jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
  )
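  # Schematically, for a kernel with refs [x, o] (input then output),
  # `ad.jvp_jaxpr` yields invars [x, o, dx, do]; the shuffle below restores
  # the pallas_call convention [x, dx, o, do] (a comment-only illustration).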
  invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
  effs = []
  for eff in jvp_jaxpr.effects:
    if isinstance(eff, effects.JaxprInputEffect):
      eff = eff.replace(
          input_index=invars.index(jvp_jaxpr.invars[eff.input_index])
      )
    effs.append(eff)
  jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
  if debug:
    print(jvp_jaxpr)
  in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
  jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
  out_flat = pallas_call_p.bind(
      *primals,
      *tangents,
      jaxpr=jvp_jaxpr,
      name=f"{name}_jvp",
      in_shapes=(*in_shapes, *in_shapes),
      out_shapes=(*out_shapes, *out_shapes),
      grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
      interpret=interpret,
      debug=debug,
      input_output_aliases=(),
      compiler_params=compiler_params,
  )
  out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
  return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray,
                         dim: int | batching.NotMapped,
                         block_mapping: BlockMapping | None) -> BlockMapping:
  def _block_map_function(new_idx, *args):
    if block_mapping is None:
      indices = [0] * len(aval.shape)
    else:
      indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
                                    block_mapping.index_map_jaxpr.consts,
                                    *args)
    if dim is not batching.not_mapped:
      indices.insert(dim, new_idx)
    return tuple(indices)
  i32_aval = jax_core.ShapedArray((), jnp.int32)
  if block_mapping is None:
    idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1)
  else:
    idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
  with grid_mapping.trace_env():
    block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(_block_map_function), idx_avals)
  shape = aval.shape if block_mapping is None else block_mapping.block_shape
  if dim is batching.not_mapped:
    new_block_shape = shape
  else:
    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
  jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
  if block_mapping is None:
    return BlockMapping(
        block_shape=new_block_shape,
        index_map_jaxpr=jaxpr,
        indexing_mode=pallas_core.blocked,
    )
  return block_mapping.replace(block_shape=new_block_shape,
                               index_map_jaxpr=jaxpr)

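# Schematically, when the operand is mapped, batching a block mapping inserts
# the new grid index into the returned block index at position `dim`, e.g. an
# index map `lambda i, j: (i, j)` with dim=0 becomes
# `lambda b, i, j: (b, i, j)`, and the block shape gets a `mapped` entry in
# that slot. Unmapped operands keep their indices and block shape, ignoring
# the batch index entirely (a comment-only illustration).
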
def _broadcast_input_output_aliases(
    args: Sequence[jax.Array],
    dims: Sequence[int | batching.NotMapped],
    *,
    input_output_aliases: tuple[tuple[int, int], ...],
    axis_size: int,
) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]:
  """Broadcast input/output operands.

  When we have input/output aliasing, since the output will be mapped, we need
  to make sure to broadcast the input across that dimension if it is not
  mapped. If the input is mapped, but on a different axis, we transpose the
  input to match the output.
  """
  args_ = list(args)
  dims_ = list(dims)
  for input_index, _ in input_output_aliases:
    dim = dims_[input_index]
    dims_[input_index] = 0
    if dim is batching.not_mapped:
      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
    elif dim != 0:
      # TODO(cjfj): Change output batching axis instead?
      args_[input_index] = jnp.moveaxis(args[input_index], dim, 0)

  return tuple(args_), tuple(dims_)

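# E.g. with axis_size=4, an unmapped aliased input of shape (3,) is broadcast
# to (4, 3) and marked as mapped on axis 0, while an input already mapped on
# axis 1 has its batch axis moved to axis 0 (a comment-only note).
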
def _batch_with_explicit_loop(
    args: Sequence[jax.Array],
    dims: Sequence[int | batching.NotMapped],
    *,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    grid_mapping: GridMapping,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    interpret: bool,
    compiler_params: Any,
):
  """Batch the pallas_call by calling it in a loop over the batch size.

  This function provides a fallback implementation of batching a pallas_call
  for the cases in which adding a batch dimension to the pallas grid is not
  supported. This is currently the case when the batched dimension corresponds
  to a dynamic axis or a scalar prefetch argument.

  This implementation builds an HLO loop that dynamically slices the inputs
  according to the current iteration index and dynamically updates an
  (initially empty) output allocation.
  """
  if not dims:
    raise NotImplementedError("vmapping pallas_call with no arguments.")

  (axis_size,) = {
      arg.shape[dim]
      for arg, dim in zip(args, dims)
      if dim is not batching.not_mapped
  }

  args, dims = _broadcast_input_output_aliases(
      args,
      dims,
      input_output_aliases=input_output_aliases,
      axis_size=axis_size,
  )

  # The output arrays are completely overwritten, so we can just initialize
  # empty arrays.
  initial_state = [
      jnp.empty(
          tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype
      )
      for out_shape in out_shapes
  ]

  def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
    batch_args = []

    for arg, dim in zip(args, dims):
      # If the argument is mapped, extract a slice of size 1 in the mapped
      # dimension at the current index.
      if dim is batching.not_mapped:
        batch_args.append(arg)
      else:
        batch_args.append(
            jnp.squeeze(
                jax.lax.dynamic_slice_in_dim(
                    operand=arg,
                    start_index=batch_index,
                    slice_size=1,
                    axis=dim,
                ),
                axis=dim,
            )
        )

    batch_out = pallas_call_p.bind(
        *batch_args,
        jaxpr=jaxpr,
        name=name,
        in_shapes=in_shapes,
        out_shapes=out_shapes,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
    )
    for i, batch_out_array in enumerate(batch_out):
      state[i] = jax.lax.dynamic_update_index_in_dim(
          state[i],
          batch_out_array,
          batch_index,
          axis=0,
      )

    return state

  result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False)

  return result, (0,) * len(result)

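# Semantically, the loop above computes the same thing as stacking unbatched
# calls per example (a comment-only sketch; `call_unbatched` and `slice_at`
# are hypothetical helpers):
#
#   jnp.stack([call_unbatched(*slice_at(args, i)) for i in range(axis_size)])
#
# but it uses dynamic_slice / dynamic_update_slice inside fori_loop, so the
# batch dimension does not unroll into the HLO.
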
def _pallas_call_batching_rule(
    args,
    dims,
    *,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    grid_mapping: GridMapping,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    interpret: bool,
    compiler_params: Any,
):

  def _maybe_squeeze_out_bdim(
      x: jax.Array, bdim: int | batching.NotMapped
  ) -> jax.Array:
    if bdim is batching.not_mapped:
      return x
    return jnp.squeeze(x, axis=bdim)

  axis_size, = {x.shape[d] for x, d in zip(args, dims)
                if d is not batching.not_mapped}
  if axis_size == 1:
    # Why are we even vmapping?
    args = map(_maybe_squeeze_out_bdim, args, dims)
    out = pallas_call_p.bind(
        *args,
        jaxpr=jaxpr,
        name=name,
        in_shapes=in_shapes,
        out_shapes=out_shapes,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
    )
    return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)

  # The first num_dynamic_grid_bounds arguments are size-1 arrays that store
  # the size of the dynamic bounds.
  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  dynamic_grid_dims, dims = split_list(
      dims, [grid_mapping.num_dynamic_grid_bounds]
  )
  if all(
      bdim is batching.not_mapped or arg.shape[bdim] == 1
      for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
  ):
    dynamic_grid_args = safe_map(
        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims
    )
  elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
    # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid
    # bounds.
    return _batch_with_explicit_loop(
        args=dynamic_grid_args + args,
        dims=dynamic_grid_dims + dims,
        jaxpr=jaxpr,
        name=name,
        in_shapes=in_shapes,
        out_shapes=out_shapes,
        grid_mapping=grid_mapping,
        input_output_aliases=input_output_aliases,
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
    )
  else:
    pass  # No dynamic grid dimensions
  del dynamic_grid_dims
  if grid_mapping.num_index_operands:
    scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
    scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
    # Ordinarily, adding support for scalar prefetch in vmap would involve
    # modifying the block specs in a nontrivial way. However, if we are only
    # vmapping over 1-sized dimensions, we can just get rid of the dimensions
    # and pretend we were never vmapping over them at all.
    if all(
        bdim is batching.not_mapped or arg.shape[bdim] == 1
        for arg, bdim in zip(scalar_args, scalar_bdims)
    ):
      scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
      scalar_bdims = [None] * len(scalar_args)
      args = (*scalar_args, *args)
      dims = (*scalar_bdims, *bdims)
    else:
      # TODO(amagni,sharadmv,apaszke): enable efficient batching over
      # prefetched scalar args.
      return _batch_with_explicit_loop(
          args=scalar_args + args,
          dims=scalar_bdims + bdims,
          jaxpr=jaxpr,
          name=name,
          in_shapes=in_shapes,
          out_shapes=out_shapes,
          grid_mapping=grid_mapping,
          input_output_aliases=input_output_aliases,
          debug=debug,
          interpret=interpret,
          compiler_params=compiler_params,
      )

  if not dims:
    raise NotImplementedError("vmapping pallas_call with no arguments.")
  block_mappings = grid_mapping.block_mappings
  avals = [v.aval for v in jaxpr.invars]
  # How should we pick output dimensions? This actually matters because XLA
  # can't optimize our pallas kernels, and this layout impacts performance.
  # `vmap` doesn't really offer a way of inferring good output dimensions,
  # so for now we just use 0.
  # TODO(sharadmv): explore inferring better output dimensions via a heuristic
  # TODO(sharadmv): explore a long term solution to output dim inference

  args, dims = _broadcast_input_output_aliases(
      args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
  )

  all_dims = list(dims) + [0] * len(out_shapes)

  num_index_operands = grid_mapping.num_index_operands
  num_scratch_operands = grid_mapping.num_scratch_operands

  # Only add a batch dimension for the avals that actually have a grid mapping.
  # This excludes scalar prefetch inputs (the first in the list) and scratch
  # operands (the last in the list).
  avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
  batched_block_mappings = map(
      partial(_batch_block_mapping, grid_mapping),
      avals_to_batch,
      all_dims[num_index_operands:],
      block_mappings,
  )

  batched_in_shapes = tuple(
      jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
                           tuple_insert(x.shape, dim, axis_size),
                           x.dtype)
      for x, dim in zip(in_shapes, dims))
  batched_out_shapes = tuple(
      jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
      for x in out_shapes)

  batched_grid_mapping = grid_mapping.replace(
      grid=(axis_size, *grid_mapping.grid),
      block_mappings=tuple(batched_block_mappings),
      mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
  out = pallas_call_p.bind(
      *dynamic_grid_args,
      *args,
      jaxpr=jaxpr,
      name=f"batched_{name}",
      in_shapes=batched_in_shapes,
      out_shapes=batched_out_shapes,
      grid_mapping=batched_grid_mapping,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
  )
  return out, (0,) * len(out)


batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

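# The common (non-fallback) path above batches by prepending a grid axis
# rather than looping. A hedged, comment-only sketch (`kernel` is
# hypothetical):
#
#   call = pallas_call(kernel,
#                      out_shape=jax.ShapeDtypeStruct((128,), jnp.float32))
#   ys = jax.vmap(call)(xs)  # xs: (batch, 128) -> ys: (batch, 128)
#
# One kernel is reused with grid (batch, *old_grid) and block mappings
# extended by `_batch_block_mapping`.
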
def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
  """Hoists the constants in the given jaxpr into invars.

  Args:
    jaxpr: The jaxpr.

  Returns:
    A new jaxpr where the constants were hoisted into invars as ``Ref``s.
    The invars for the constants are added *before* any existing invars.
  """
  if not jaxpr.constvars:
    return jaxpr  # Nothing to hoist.

  is_const_ref = [
      isinstance(var.aval, state.AbstractRef) for var in jaxpr.constvars
  ]
  const_avals = [
      var.aval if is_ref else state.AbstractRef(var.aval)
      for is_ref, var in zip(is_const_ref, jaxpr.constvars)
  ]
  in_avals = const_avals + [var.aval for var in jaxpr.invars]

  def _hoist(*consts_args):
    all_consts, args = split_list(consts_args, [len(const_avals)])
    # We immediately read the const values out of the `Ref`s.
    all_consts = [
        c if is_ref else sp.ref_get(c, ())
        for is_ref, c in zip(is_const_ref, all_consts)
    ]
    return jax_core.eval_jaxpr(jaxpr, all_consts, *args)

  hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_hoist), in_avals)
  assert not consts, "All consts should have been converted to refs"
  return hoisted_jaxpr

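# For instance, a kernel that closes over an array (a comment-only sketch;
# `c` and `kernel` are hypothetical):
#
#   c = jnp.arange(8.0)
#   def kernel(x_ref, o_ref):
#     o_ref[...] = x_ref[...] + c
#
# traces to a jaxpr with `c` as a constvar; `_hoist_consts_to_refs` rewrites
# it so that `c` arrives as a leading `Ref` invar that is read once up front.
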
def checkify_pallas_kernel_body_jaxpr(
    body_jaxpr: jax_core.ClosedJaxpr,
    enabled_errors,
    error: checkify.Error,
    grid_mapping: GridMapping) -> tuple[
        jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
  err_vals, err_tree = tree_util.tree_flatten(error)
  err_vals = map(checkify.get_shaped_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]

  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr(
        body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return checked_jaxpr, out_tree, error_effects

def pallas_call_checkify_rule(error: checkify.Error,
                              enabled_errors,
                              *args: jax_core.Value,
                              jaxpr: jax_core.Jaxpr,
                              interpret: bool,
                              input_output_aliases: tuple[tuple[int, int], ...],
                              grid_mapping: GridMapping,
                              out_shapes,
                              **kwargs):
  # TODO(b/346651778): Support TPU/GPU checkify.
  if not interpret:
    raise NotImplementedError(
        "Checkify for pallas_call only supports interpret mode.")
  # We implement the checkify rule in 4 steps:
  # 1) First, trace the kernel body to get the expected error shapes.
  # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
  #    and outputs.
  # 3) Create a new kernel which stores the errors in output memrefs instead of
  #    returning them, since pallas kernels do not return outputs.
  # 4) Create block specs for the error state and call pallas_call with
  #    the new kernel.
  dynamic_grid_bounds, scalars, args = split_list(  # type: ignore
      args, [grid_mapping.num_dynamic_grid_bounds,
             grid_mapping.num_index_operands]
  )
  num_scalars = len(scalars)
  num_invars = len(jaxpr.invars)
  num_inputs_outputs = (
      num_invars
      - grid_mapping.num_index_operands
      - grid_mapping.num_scratch_operands
  )
  num_kernel_inputs = len(args)
  num_scratch = num_invars - num_inputs_outputs
  num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs

  # Trace the jaxpr to get an initial error value so the kernel jaxpr has all
  # of the required inputs.
  closed_jaxpr = pe.close_jaxpr(jaxpr)
  _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
      closed_jaxpr, enabled_errors, error, grid_mapping)
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jax.tree.flatten(error)
  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)

  # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
  # all enabled errors removed, but have the error as inputs and return values.
  input_avals = [v.aval for v in jaxpr.invars]
  num_err_vals = len(err_vals)
  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
  checkify_in_avals = [*shaped_err_avals,
                       *shaped_input_avals]
  closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
        closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)

  # Create a new kernel to remove the errors as return values and instead
  # write them to a memref. This is because pallas kernels are expected
  # to have no return values but instead write their outputs to a ref.
  def checked_kernel_fn(*args):
    (scalars, _, inputs, out_error_refs, outputs, scratch
     ) = split_list(
         args,
         [num_scalars, num_err_vals,
          num_kernel_inputs, num_err_vals, num_kernel_outputs])
    input_error_vals = [err_ref[...] for err_ref in out_error_refs]
    # We need to re-order the inputs here. A checkified jaxpr always expects
    # errors before other arguments.
    jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
    assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
    result_flat = jax.core.eval_jaxpr(
        checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
    output_errors, _ = split_list(result_flat, [num_err_vals])
    # Store new errors back in the error refs.
    for out_ref, error in zip(out_error_refs, output_errors):
      out_ref[...] = error
    return []

  # Trace the new checked_kernel_fn with Memref inputs so that
  # we can replace the old kernel jaxpr with the new checked jaxpr in
  # pallas_call.
  # TODO(justinfu): Place errors in scalar memory for non-interpret mode.
  error_mem_space = None
  error_memref_aval = [pallas_core.AbstractMemoryRef(
      err_val, error_mem_space) for err_val in shaped_err_avals]
  shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
      shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
  retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
                      *error_memref_aval, *output_aval, *scratch_aval]
  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
  debug = pe.debug_info(
      checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False,
      "checkify_pallas")
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrapped_kernel_with_err, jaxpr_flat_avals, debug)

  # Prepare pallas_call inputs. We need to create new block specs
  # for the new error inputs and outputs.
  scalar_avals = map(checkify.get_shaped_aval, scalars)
  error_block_specs = [no_block_spec] * num_err_vals
  grid_avals = [
      jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
  # TODO(justinfu): Place these in device-specific scalar memory.
  scalar_ref_avals = [
      pallas_core.AbstractMemoryRef(
          jax_core.ShapedArray(aval.shape, aval.dtype), None)
      for aval in scalar_avals]
  grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
  error_block_mappings = map(
      partial(
          pallas_core._convert_block_spec_to_block_mapping,
          (*grid_avals, *scalar_ref_avals),
          in_tree=grid_tree,
          grid=grid_mapping.grid,
          mapped_dims=grid_mapping.mapped_dims),
      error_block_specs, error_memref_aval)
  input_block_mappings, output_block_mappings = split_list(
      grid_mapping.block_mappings, [num_kernel_inputs,])
  grid_mapping_with_error = grid_mapping.replace(
      block_mappings=(*error_block_mappings, *input_block_mappings,
                      *error_block_mappings, *output_block_mappings)
  )
  error_out_shapes = tuple(
      jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
  # Bump all input_output_aliases by num_err_vals to make room for the error
  # values.
  # TODO(justinfu): Don't bump scalars here.
  input_output_aliases = tuple(
      (i + num_err_vals, o + num_err_vals) for (i, o) in input_output_aliases)
  input_output_aliases_with_error = tuple(
      (i + num_scalars, i) for i in range(num_err_vals)) + input_output_aliases

  new_vals_in = [*scalars, *err_vals, *args]
  result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
                              jaxpr=final_jaxpr,
                              interpret=interpret,
                              grid_mapping=grid_mapping_with_error,
                              input_output_aliases=input_output_aliases_with_error,
                              out_shapes=error_out_shapes + out_shapes,
                              **kwargs)
  errors, results = split_list(result, [num_err_vals])
  new_error, _ = jax.tree.unflatten(out_tree, errors)
  return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

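# A hedged, comment-only sketch of the rule above in action (`kernel` and
# `out_shape` are hypothetical), mirroring the four steps documented at the
# top of the rule:
#
#   checked = checkify.checkify(
#       pallas_call(kernel, out_shape=out_shape, interpret=True))
#   err, out = checked(x)
#   err.throw()  # Raises if a check (e.g. checkify.check) failed in-kernel.
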
@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
                    flat_out_avals, in_tree, out_tree, interpret: bool):
  avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
                                                   flat_out_avals, out_tree)
  if interpret:
    avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
  wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), jaxpr_in_tree)
  debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False,
                        "pallas_call")
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
                                                     jaxpr_flat_avals, debug)
  if consts:
    jaxpr = _hoist_consts_to_refs(jaxpr)
    # Pad ``block_mappings`` to account for the hoisted constants.
    grid_mapping = grid_mapping.replace(
        block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
        num_constant_operands=len(consts),
    )
  return grid_mapping, jaxpr, consts, out_tree_thunk()

def _extract_function_name(f: Callable, name: str | None) -> str:
  if name is None:
    name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
  return name

_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
    "jax_pallas_use_mosaic_gpu",
    default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
    help=(
        "If True, lower Pallas kernels to the experimental Mosaic GPU"
        " dialect, instead of Triton IR."
    ),
)

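# The default is read from the environment variable JAX_PALLAS_USE_MOSAIC_GPU
# at import time; like other JAX config options, it should also be settable
# at runtime via jax.config.update("jax_pallas_use_mosaic_gpu", True)
# (a comment-only note; the runtime-update path is an assumption).
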
def _unsupported_lowering_error(platform: str) -> Exception:
  return ValueError(
      f"Cannot lower pallas_call on platform: {platform}. To use Pallas on"
      " GPU, install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU,"
      " install jaxlib TPU and libtpu. See"
      " https://jax.readthedocs.io/en/latest/installation.html."
  )

def _pallas_call_lowering(
    ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
):
  if interpret:
    # If we are in interpret mode, we don't care what platform we are on.
    impl = partial(_pallas_call_impl, **params, interpret=True)
    return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

  def cpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    raise ValueError("Only interpret mode is supported on CPU backend.")

  def tpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    try:
      from jax._src.pallas.mosaic import pallas_call_registration
    except ImportError:
      raise _unsupported_lowering_error("tpu")
    else:
      return pallas_call_registration.pallas_call_tpu_lowering_rule(
          ctx, *in_nodes, **params
      )

  def gpu_lowering(ctx: mlir.LoweringRuleContext,
                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                   **params):
    try:
      if _PALLAS_USE_MOSAIC_GPU.value:
        from jax._src.pallas.mosaic_gpu import pallas_call_registration
      else:
        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
    except ImportError:
      raise _unsupported_lowering_error("gpu")
    else:
      return pallas_call_registration.pallas_call_lowering(
          ctx, *in_nodes, **params
      )

  return mlir.lower_per_platform(ctx, "pallas_call",
                                 dict(cpu=cpu_lowering,
                                      tpu=tpu_lowering,
                                      cuda=gpu_lowering,
                                      rocm=gpu_lowering),
                                 None,  # default_rule
                                 effects.no_effects,
                                 *in_nodes,
                                 interpret=interpret,
                                 **params)


mlir.register_lowering(pallas_call_p, _pallas_call_lowering)

def pallas_call(
    f: Callable[..., None],
    out_shape: Any,
    *,
    grid_spec: GridSpec | None = None,
    debug: bool = False,
    grid: Grid | None = None,
    in_specs: BlockSpecTree = no_block_spec,
    out_specs: BlockSpecTree = no_block_spec,
    input_output_aliases: dict[int, int] = {},
    interpret: bool = False,
    name: str | None = None,
    compiler_params: dict[str, Any] | None = None,
) -> Callable[..., Any]:
  name = _extract_function_name(f, name)
  if compiler_params is None:
    compiler_params = {}
  if grid is not None and grid_spec is not None:
    raise ValueError("Cannot specify both grid and grid_spec at the same time.")
  if grid_spec is None:
    grid_spec = GridSpec(grid, in_specs, out_specs)
  grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
  if isinstance(out_shape, list):
    out_shape = tuple(out_shape)
  flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
                     for x in flat_out_shapes]

  @jax.jit
  def wrapped(*args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
                          for a in flat_args)
    flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
                           for v in flat_out_shapes)
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
        f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
        out_tree, interpret=interpret)
    out_flat = pallas_call_p.bind(
        *dynamic_grid_bounds, *consts, *flat_args,
        jaxpr=jaxpr, name=name,
        in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                        for a in flat_args),
        out_shapes=tuple(flat_out_shapes), debug=debug,
        interpret=interpret,
        grid_mapping=grid_mapping,
        input_output_aliases=tuple(input_output_aliases.items()),
        compiler_params=compiler_params)
    out = tree_util.tree_unflatten(out_tree, out_flat)
    return out
  return wrapped
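
# Example usage (a hedged, comment-only sketch; `add_vectors` is hypothetical,
# and with no grid or block specs each ref covers the whole operand):
#
#   def add_vectors(x_ref, y_ref, o_ref):
#     o_ref[...] = x_ref[...] + y_ref[...]
#
#   add = pallas_call(
#       add_vectors,
#       out_shape=jax.ShapeDtypeStruct((1024,), jnp.float32),
#   )
#   z = add(x, y)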