rocm_jax/jax/_src/pallas/primitives.py
George Necula 6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00

846 lines
28 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas-specific JAX primitives."""
from __future__ import annotations
import enum
import functools
import string
from typing import Any, Callable
import jax
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax.interpreters import mlir
import jax.numpy as jnp
partial = functools.partial
# Short aliases for the indexing types used throughout this module.
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
# Shadow the builtin `map`/`zip` with `util.safe_map`/`util.safe_zip`
# (presumably the length-checking variants -- see jax._src.util). The tuple
# assignment evaluates the right-hand side first, so the builtins remain
# reachable as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
program_id_p = jax_core.Primitive("program_id")


def program_id(axis: int) -> jax.Array:
  """Returns the position of the current kernel invocation along a grid axis.

  For instance, in a kernel launched over a 2D ``grid``, the invocation at grid
  coordinates ``(1, 2)`` observes ``program_id(axis=0) == 1`` and
  ``program_id(axis=1) == 2``.

  Args:
    axis: the grid axis whose position should be returned.
  """
  return program_id_p.bind(axis=axis)


@program_id_p.def_custom_bind
def program_id_bind(*, axis: int):
  env = pallas_core.current_grid_env()
  if not env:
    # Asking the axis frame for the size of `axis` validates it (an invalid
    # axis raises there); the size itself is not needed.
    pallas_core.axis_frame().size(axis)
    return jax_core.Primitive.bind(program_id_p, axis=axis)
  # A locally installed grid environment answers the query directly.
  return env[axis].index


@program_id_p.def_abstract_eval
def _program_id_abstract_eval(**_):
  # The program id is always an int32 scalar.
  return jax_core.ShapedArray((), jnp.int32)
num_programs_p = jax_core.Primitive("num_programs")


def num_programs(axis: int) -> int | jax.Array:
  """Returns the number of programs (the grid size) along the given axis."""
  return num_programs_p.bind(axis=axis)


@num_programs_p.def_custom_bind
def _num_programs_bind(*, axis: int):
  # Prefer a local grid environment when one is active.
  env = pallas_core.current_grid_env()
  if env:
    return env[axis].size
  # Otherwise read the size from the axis frame. A static size is returned
  # as-is; only a dynamic grid dimension needs the primitive to be emitted.
  axis_size = pallas_core.axis_frame().size(axis)
  if axis_size is not pallas_core.dynamic_grid_dim:
    return axis_size
  return jax_core.Primitive.bind(num_programs_p, axis=axis)


@num_programs_p.def_abstract_eval
def _num_programs_abstract_eval(**_):
  # A dynamic grid size is an int32 scalar.
  return jax_core.ShapedArray((), jnp.int32)
class AtomicOpType(enum.Enum):
  """The read-modify-write operations understood by `atomic_rmw_p`.

  Each member's value is the lowercase suffix used in the corresponding
  ``atomic_<value>`` function name and in dtype-validation error messages.
  """

  XCHG = "xchg"
  ADD = "add"
  MAX = "max"
  MIN = "min"
  AND = "and"
  OR = "or"
  XOR = "xor"
atomic_rmw_p = jax_core.Primitive("atomic_rmw")


def _atomic_rmw_discharge_rule(
    in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType
):
  """Discharges `atomic_rmw_p` into functional array updates.

  The read-modify-write is modelled as follows. Read the current value at the
  indexed location, combine it with ``val``, write the combined value back,
  and return the value that was read.

  Returns:
    A pair ``(new_invals, out)``. ``new_invals[0]`` is the updated ref value
    and the remaining inputs are not updated (``None``). ``out`` is the value
    stored at the indexed location before the update.

  Raises:
    NotImplementedError: for multiple indexers, a mask, strided slices, or an
      indexer mixing ``Slice``s with non-scalar array indices.
  """
  del out_avals  # Unused.
  ref, indexers, val, mask = args_tree.unflatten(args_flat)
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer is supported.")
  idx = indexers[0]
  if mask is not None:
    raise NotImplementedError
  # `combine(old, new)` computes the value written back to the ref. Every
  # member of AtomicOpType is handled, matching what the abstract eval accepts.
  if atomic_type == AtomicOpType.XCHG:
    combine = lambda old, new: new
  elif atomic_type == AtomicOpType.ADD:
    combine = lambda old, new: old + new
  elif atomic_type == AtomicOpType.MAX:
    combine = jnp.maximum
  elif atomic_type == AtomicOpType.MIN:
    combine = jnp.minimum
  elif atomic_type == AtomicOpType.AND:
    combine = jnp.bitwise_and
  elif atomic_type == AtomicOpType.OR:
    combine = jnp.bitwise_or
  elif atomic_type == AtomicOpType.XOR:
    combine = jnp.bitwise_xor
  else:
    raise NotImplementedError(atomic_type)
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    # TODO(b/329733289): support strided atomics in interpret mode. Until then
    # fail loudly instead of silently ignoring the stride.
    if any(isinstance(s, Slice) and s.stride > 1 for s in idx.indices):
      raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # As in the load/swap discharge rules, pad the ref so that a window that
    # overruns it is not shifted backwards by lax.dynamic_slice /
    # lax.dynamic_update_slice. The padding is stripped again below.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    old = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    # `val` lacks the dims indexed by scalars; re-insert them as size-1 dims.
    val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
    new = combine(old, val[val_indexer])
    x_new = lax.dynamic_update_slice(ref, new, start_indices=slice_starts)
    x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes)
    # Drop the size-1 dims that came from scalar indices.
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = old[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    x_new = ref.at[idx.indices].set(combine(out, val))
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out


state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)
@atomic_rmw_p.def_effectful_abstract_eval
def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
  """Validates the ref dtype for `atomic_type`, then type-checks like `swap`.

  Raises:
    ValueError: if the ref dtype is not supported by the atomic operation.
  """
  ref, _, _, _ = args_tree.unflatten(avals_flat)
  if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
    raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.")
  # Every entry must be a dtype object. A set lookup compares hashes first,
  # and the scalar type `jnp.bfloat16` does not hash like
  # `jnp.dtype(jnp.bfloat16)`, so listing the bare type never matched a bf16
  # ref.
  if ref.dtype in {
      jnp.dtype("bool"),
      jnp.dtype("int8"),
      jnp.dtype("int16"),
      jnp.dtype(jnp.bfloat16),
  }:
    raise ValueError(
        f"`atomic_{atomic_type.value}` does not support {ref.dtype}."
    )
  return _swap_abstract_eval(*avals_flat, args_tree=args_tree)
def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
                atomic_type: AtomicOpType):
  """Binds `atomic_rmw_p` for a ref (or ref view), an index and a value."""
  ref, ref_indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw")
  flat_args, tree = tree_util.tree_flatten((ref, ref_indexers, val, mask))
  return atomic_rmw_p.bind(*flat_args, args_tree=tree, atomic_type=atomic_type)
def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically replaces the value at ``x_ref_or_view[idx]`` with ``val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The value to store.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the update.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.XCHG, mask=mask)


def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] += val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The value to add.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.ADD, mask=mask)


def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The value to compare against.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.MAX, mask=mask)


def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The value to compare against.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.MIN, mask=mask)


def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] &= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The bitwise operand.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.AND, mask=mask)


def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] |= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The bitwise operand.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.OR, mask=mask)


def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically performs ``x_ref_or_view[idx] ^= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    val: The bitwise operand.
    mask: Optional mask, forwarded unchanged to the primitive.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(x_ref_or_view, idx, val,
                     atomic_type=AtomicOpType.XOR, mask=mask)
atomic_cas_p = jax_core.Primitive("atomic_cas")


@atomic_cas_p.def_effectful_abstract_eval
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
  # `cmp` and `val` must agree exactly, and all three operands are scalars.
  if not (cmp_aval.dtype == val_aval.dtype
          and cmp_aval.shape == val_aval.shape):
    raise ValueError("cmp and val must have identical dtypes and shapes")
  for operand_name, aval in (("ref", ref_aval), ("cmp", cmp_aval),
                             ("val", val_aval)):
    if aval.shape:
      raise ValueError(f"{operand_name} must be scalar.")
  out_aval = jax_core.ShapedArray(val_aval.shape, val_aval.dtype)
  return out_aval, {state.WriteEffect(0)}


def atomic_cas(ref, cmp, val):
  """Atomically compares the value in ``ref`` with ``cmp`` and swaps in ``val``.

  Args:
    ref: The ref to operate on.
    cmp: The expected value to compare against.
    val: The value to store when the comparison succeeds.

  Returns:
    The value stored in the ref prior to the atomic operation.
  """
  return atomic_cas_p.bind(ref, cmp, val)


@state_discharge.register_discharge_rule(atomic_cas_p)
def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
  del in_avals, out_avals  # Unused.
  # Write `val` only where the current value equals `cmp`; return the old value.
  swapped = jnp.where(ref == cmp, val, ref)
  return (swapped, None, None), ref
max_contiguous_p = jax_core.Primitive("max_contiguous")
# Both the eager implementation and the default lowering return `x` unchanged.
max_contiguous_p.def_impl(lambda x, **_: x)
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])


def max_contiguous(x, values):
  """Attaches a ``max_contiguous`` hint to ``x`` and returns ``x`` unchanged.

  Args:
    x: the value to annotate.
    values: a single hint value, or a list/tuple of them. It is always stored
      on the primitive as a flat list; a tuple used to be wrapped as a single
      nested element.
  """
  if isinstance(values, (list, tuple)):
    values = list(values)
  else:
    values = [values]
  return max_contiguous_p.bind(x, values=values)


@max_contiguous_p.def_abstract_eval
def _max_contiguous_abstract_eval(aval, **_):
  return aval
multiple_of_p = jax_core.Primitive("multiple_of")
# Both the eager implementation and the default lowering return `x` unchanged.
multiple_of_p.def_impl(lambda x, **_: x)
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])


def multiple_of(
    x: jax.Array, values: list[int] | tuple[int, ...] | int
) -> jax.Array:
  """Attaches a ``multiple_of`` hint to ``x`` and returns ``x`` unchanged.

  Args:
    x: the value to annotate.
    values: a single int, or a list/tuple of ints. It is always stored on the
      primitive as a flat list; a tuple used to be wrapped as a single nested
      element.
  """
  if isinstance(values, (list, tuple)):
    values = list(values)
  else:
    values = [values]
  return multiple_of_p.bind(x, values=values)


@multiple_of_p.def_abstract_eval
def _multiple_of_abstract_eval(aval, **_):
  return aval
load_p = jax_core.Primitive('masked_load')


@load_p.def_effectful_abstract_eval
def _load_abstract_eval(*avals_flat, args_tree, **_):
  # The result takes the shape produced by the last indexer and the ref's
  # dtype; `mask` and `other` do not affect the result type.
  ref_aval, ref_indexers, _mask, _other = args_tree.unflatten(avals_flat)
  out_shape = ref_indexers[-1].get_indexer_shape()
  out_aval = jax_core.ShapedArray(out_shape, ref_aval.dtype)
  return out_aval, {state.ReadEffect(0)}
def _load_pp_rule(eqn, context, settings):
  # Pretty prints `a = load x i` as `a <- x[i]`, followed by ` mask=m` and
  # ` other=o` when those operands are present.
  y, = eqn.outvars
  x, indexers, mask, other = tree_util.tree_unflatten(
      eqn.params["args_tree"], eqn.invars)
  # The list is built left to right, so names are assigned to `y` before `x`.
  pieces = [
      jax_core.pp_vars([y], context, print_shapes=settings.print_shapes),
      pp.text(' <- '),
      sp.pp_ref_indexers(context, x, indexers),
  ]
  for label, operand in (("mask=", mask), ("other=", other)):
    if operand is None:
      continue
    pieces += [
        pp.text(" "),
        pp.text(label),
        pp.text(jax_core.pp_var(operand, context)),
    ]
  return pp.concat(pieces)


jax_core.pp_eqn_rules[load_p] = _load_pp_rule
def _load_jvp(primals, tangents, args_tree, **params):
  ref, indexers, mask, other = args_tree.unflatten(primals)
  ref_dot, _, _, other_dot = args_tree.unflatten(tangents)
  if other_dot is not None:
    # Materialize a possibly symbolic-zero tangent before binding it.
    other_dot = ad_util.instantiate(other_dot)

  def bind_load(ref_operand, other_operand):
    # The tangent load reuses the primal indexers and mask.
    leaves = tree_util.tree_leaves((ref_operand, indexers, mask, other_operand))
    return load_p.bind(*leaves, args_tree=args_tree, **params)

  primal_out = bind_load(ref, other)
  tangent_out = bind_load(ref_dot, other_dot)
  return primal_out, tangent_out


ad.primitive_jvps[load_p] = _load_jvp
def uninitialized_value(shape, dtype):
  """Returns an array of the given shape and dtype filled with a sentinel.

  The sentinel is used, e.g., as padding for regions holding no real data.
  It is NaN for floating and complex dtypes, the minimum representable value
  for integer dtypes, and ``False`` for booleans.

  Raises:
    NotImplementedError: for any other dtype.
  """
  # `inexact` covers complex dtypes as well as floating ones; previously
  # complex refs could not be padded at all.
  if jnp.issubdtype(dtype, jnp.inexact):
    return jnp.full(shape, jnp.nan, dtype)
  if jnp.issubdtype(dtype, jnp.integer):
    return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
  if jnp.issubdtype(dtype, jnp.bool_):
    return jnp.full(shape, False, dtype)
  raise NotImplementedError(dtype)
def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
                                                 slice_sizes, unpad=False):
  """Pads (or un-pads) ``value`` so dynamic slices never need to be shifted.

  ``lax.dynamic_slice`` and ``lax.dynamic_update_slice`` move the start index
  backwards when the requested window would overrun the array. This function
  pads the high end of every dimension by the matching slice size, using
  :func:`uninitialized_value`, so a requested window never overruns.

  For example, with ``[1., 2., 3., 4.]``, a size-4 slice starting at index 2
  yields ``[3., 4., nan, nan]`` after padding. On the unpadded array it would
  yield ``[1., 2., 3., 4.]``.

  With ``unpad=True`` the same amounts are removed (negative padding),
  inverting the operation.
  """
  direction = -1 if unpad else 1
  config = tuple((0, direction * size, 0) for size in slice_sizes)
  fill = uninitialized_value(shape=(), dtype=value.dtype)
  return lax.pad(value, padding_config=config, padding_value=fill)


_unpad_values_to_avoid_dynamic_slice_oob_shift = partial(
    _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True)
@state_discharge.register_discharge_rule(load_p)
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  """Discharges `load_p` into a functional read of the ref's value.

  Two indexer forms are supported:
  (a) an indexer made only of `Slice`s and scalar indices, read with
      `lax.dynamic_slice` on a padded copy of the ref;
  (b) an indexer with no `Slice`s, read with NumPy-style indexing.
  When both `mask` and `other` are given, masked-off elements are replaced by
  `other`. Otherwise the mask is ignored.

  Returns:
    A pair `(new_invals, out)`. Loading updates no input, so `new_invals` is
    all `None`; `out` is the loaded value.

  Raises:
    NotImplementedError: for multiple indexers, strided slices, or an
      indexer mixing `Slice`s with non-scalar array indices.
  """
  del out_avals  # Unused.
  ref, indexers, mask, other = args_tree.unflatten(args_flat)
  # TODO(sharadmv): add support for multiple indexers
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer supported in discharge rule.")
  idx = indexers[0]
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    # TODO(b/329733289): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    # True for dims indexed by a scalar; those dims are dropped from `out`.
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
    # A scalar index becomes a size-1 window starting at that index.
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # fixes an inconsistency with lax.dynamic_slice where if the slice goes out
    # of bounds, it will instead move the start_index backwards so the slice
    # will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    # Cast every start index to a single canonical integer dtype.
    idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
    out_ones = lax.dynamic_slice(
        ref,
        [jnp.astype(s, idx_dtype) for s in slice_starts],
        slice_sizes=slice_sizes,
    )
    # Drop the size-1 dims that came from scalar indices.
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
  else:
    raise NotImplementedError
  if mask is not None and other is not None:
    out = jnp.where(mask, out, other)
  return (None,) * len(in_avals), out
swap_p = jax_core.Primitive('masked_swap')


@swap_p.def_effectful_abstract_eval
def _swap_abstract_eval(*avals_flat, args_tree, **_):
  """Checks that ``val`` matches the indexed ref region in shape and dtype.

  Raises:
    ValueError: on a shape or dtype mismatch.
  """
  ref_aval, ref_indexers, val_aval, _mask = args_tree.unflatten(avals_flat)
  indexed_shape = ref_indexers[-1].get_indexer_shape()
  if val_aval.shape != indexed_shape:
    raise ValueError(
        f"Invalid shape for `swap`. Ref shape: {ref_aval.shape}. "
        f"Value shape: {val_aval.shape}. Indices: {ref_indexers}. "
    )
  if val_aval.dtype != ref_aval.dtype:
    raise ValueError(
        f"Invalid dtype for `swap`. Ref dtype: {ref_aval.dtype}. "
        f"Value dtype: {val_aval.dtype}. "
    )
  out_aval = jax_core.ShapedArray(indexed_shape, ref_aval.dtype)
  return out_aval, {state.WriteEffect(0)}
def _swap_pp_rule(eqn, context, settings):
  """Pretty-prints a `swap_p` equation.

  `a = swap x v i` is printed as `a, x[i] <- x[i], v`. When the result is
  dropped (`_ = swap x v i`, as produced by `store`), it is printed as
  `x[i] <- v`. In both forms a present mask is appended as ` mask=m`;
  previously the dropped-result form omitted it.
  """
  y, = eqn.outvars
  x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars)
  x_i = sp.pp_ref_indexers(context, x, indexers)
  if isinstance(y, jax_core.DropVar):
    result = [x_i, pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]
  else:
    y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
    result = [
        y,
        pp.text(", "),
        x_i,
        pp.text(" <- "),
        x_i,
        pp.text(", "),
        pp.text(jax_core.pp_var(val, context)),
    ]
  if mask is not None:
    result += [
        pp.text(" "),
        pp.text("mask="),
        pp.text(jax_core.pp_var(mask, context)),
    ]
  return pp.concat(result)


jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _swap_jvp(primals, tangents, *, args_tree, **params):
  ref, indexers, val, mask = args_tree.unflatten(primals)
  ref_dot, _, val_dot, _ = args_tree.unflatten(tangents)
  # Materialize a possibly symbolic-zero tangent before binding it.
  val_dot = ad_util.instantiate(val_dot)

  def bind_swap(ref_operand, val_operand):
    # The tangent swap reuses the primal indexers and mask.
    leaves = tree_util.tree_leaves((ref_operand, indexers, val_operand, mask))
    return swap_p.bind(*leaves, args_tree=args_tree, **params)

  primal_out = bind_swap(ref, val)
  tangent_out = bind_swap(ref_dot, val_dot)
  return primal_out, tangent_out


ad.primitive_jvps[swap_p] = _swap_jvp
@state_discharge.register_discharge_rule(swap_p)
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  """Discharges `swap_p` into a functional update of the ref's value.

  Two indexer forms are supported:
  (a) an indexer made only of `Slice`s and scalar indices, handled with
      `lax.dynamic_slice`/`lax.dynamic_update_slice` on a padded copy of
      the ref;
  (b) an indexer with no `Slice`s, handled with NumPy-style indexing and
      `.at[...].set`.

  Returns:
    A pair `(new_invals, out)`. `new_invals[0]` is the ref value with `val`
    written at the indexed location (only where `mask` is true, if a mask is
    given), and the other inputs are not updated (`None`). `out` is the value
    previously stored there; where `mask` is false, `out` holds `val` instead.

  Raises:
    NotImplementedError: for multiple indexers, strided slices, or an
      indexer mixing `Slice`s with non-scalar array indices.
  """
  del out_avals  # Unused.
  ref, indexers, val, mask = args_tree.unflatten(args_flat)
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer supported in discharge rule.")
  idx = indexers[0]
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    # TODO(b/329733289): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    # Positions of the dims indexed by a scalar; they are absent from `val`
    # and from the returned `out`.
    scalar_dims = [
        i
        for i, s in enumerate(indices)
        if not isinstance(s, Slice) and not s.shape
    ]
    # A scalar index becomes a size-1 window starting at that index.
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # Fixes an inconsistency with lax.dynamic_update_slice where if the slice
    # goes out of bounds, it will instead move the start_index backwards so the
    # slice will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    out = jnp.squeeze(out, scalar_dims)
    if mask is not None:
      # Where the mask is false, keep the old value in the ref and return
      # `val` instead of the old value.
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    # Restore the scalar-indexed dims so `val` matches the window rank.
    val = jnp.expand_dims(val, scalar_dims)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    # Strip the padding added above.
    x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes)
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    if mask is not None:
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    x_new = ref.at[idx.indices].set(val)
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
         eviction_policy=None, volatile=False) -> jax.Array:
  """Loads and returns the array stored at the given index.

  Without ``mask`` and ``other`` this behaves like ``x_ref_or_view[idx]`` in
  JAX.

  Args:
    x_ref_or_view: The ref to load from.
    idx: The indexer to use.
    mask: An optional boolean mask selecting which indices to load. Where the
      mask is ``False`` and ``other`` is not given, nothing can be assumed
      about the loaded value.
    other: An optional value used at indices where the mask is ``False``.
    cache_modifier: Backend hint, forwarded unchanged as the
      ``cache_modifier`` primitive parameter.
    eviction_policy: Backend hint, forwarded unchanged as the
      ``eviction_policy`` primitive parameter.
    volatile: Forwarded as the ``is_volatile`` primitive parameter.
  """
  ref, ref_indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
  flat_args, tree = tree_util.tree_flatten((ref, ref_indexers, mask, other))
  bind_params = dict(
      args_tree=tree,
      cache_modifier=cache_modifier,
      eviction_policy=eviction_policy,
      is_volatile=volatile,
  )
  return load_p.bind(*flat_args, **bind_params)
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
         _function_name="swap") -> jax.Array:
  """Writes ``val`` at the given index and returns the value that was there.

  The arguments mean the same as in :func:`~jax.experimental.pallas.load`.
  ``_function_name`` is the caller name handed to ``get_ref_and_indexers``
  (presumably for error messages).

  Returns:
    The value stored in the ref prior to the swap.
  """
  ref, ref_indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
  flat_args, tree = tree_util.tree_flatten((ref, ref_indexers, val, mask))
  return swap_p.bind(*flat_args, args_tree=tree, eviction_policy=eviction_policy)
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
  """Writes ``val`` at the given index, discarding the previous value.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
  """
  # A store is a swap whose result is ignored. "store" is passed as the
  # function name handed to the indexer resolution.
  swap(
      x_ref_or_view,
      idx,
      val,
      mask=mask,
      eviction_policy=eviction_policy,
      _function_name="store",
  )
def dot(a, b, trans_a: bool = False, trans_b: bool = False,
        allow_tf32: bool | None = None, precision=None):
  """Computes the 2D matrix product ``op(a) @ op(b)``.

  Args:
    a: a 2D array. When ``trans_a`` is true it is contracted over its first
      dimension (i.e. used transposed).
    b: a 2D array. When ``trans_b`` is true it is contracted over its second
      dimension (i.e. used transposed).
    trans_a: whether to transpose ``a``.
    trans_b: whether to transpose ``b``.
    allow_tf32: if given, selects ``lax.Precision.HIGH`` (True) or
      ``lax.Precision.HIGHEST`` (False). Mutually exclusive with
      ``precision``.
    precision: forwarded to ``lax.dot_general``.

  Returns:
    The product, accumulated in int32 for integer operands and in float32
    otherwise.

  Raises:
    ValueError: if an operand is not 2D, or if both ``allow_tf32`` and
      ``precision`` are given.
  """
  if (a.ndim != 2) or (b.ndim != 2):
    raise ValueError("`a` and `b` must be 2D arrays.")
  lhs_contract_dim = 0 if trans_a else 1
  rhs_contract_dim = 0 if not trans_b else 1
  if allow_tf32 is not None:
    if precision is not None:
      raise ValueError("Only one of allow_tf32 and precision can be specified")
    precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
  # Accumulating integer products in float32 loses exactness beyond 2**24 and
  # changes the result kind, so integer operands accumulate in int32.
  operand_dtype = jnp.promote_types(a.dtype, b.dtype)
  if jnp.issubdtype(operand_dtype, jnp.integer):
    out_dtype = jnp.int32
  else:
    out_dtype = jnp.float32
  return jax.lax.dot_general(
      a,
      b,
      dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
      precision=precision,
      preferred_element_type=out_dtype,
  )
class PrintEffect(effects.Effect):
  """Effect recorded on `debug_print` equations; renders as ``Print``."""

  def __str__(self):
    return "Print"
# The single effect instance attached to every `debug_print` equation by its
# abstract evaluation rule.
debug_print_effect = PrintEffect()
# TODO(slebedev): Consider making the effect ordered.
# Register the print effect with JAX's lowerable / control-flow / remat /
# custom-derivative allow-lists so equations carrying it are accepted there.
effects.lowerable_effects.add_type(PrintEffect)
effects.control_flow_allowed_effects.add_type(PrintEffect)
effects.remat_allowed_effects.add_type(PrintEffect)
effects.custom_derivatives_allowed_effects.add_type(PrintEffect)
debug_print_p = jax_core.Primitive("debug_print")
# Binding returns a list of outputs (empty, per the abstract eval rule).
debug_print_p.multiple_results = True
def debug_print(fmt: str, *args: jax.ArrayLike):
  """Prints scalar values from inside a Pallas kernel.

  Args:
    fmt: A format string to be included in the output. The restrictions on the
      format string depend on the backend:

      * On GPU, when using Triton, ``fmt`` must not contain any placeholders
        (``{...}``), since it is always printed before any of the values.
      * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
        contain a placeholder for each value to be printed. Format specs and
        conversions are not supported.
      * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
        integers. If there are no placeholders, the values are printed after
        the format string.
    *args: The scalar values to print.
  """  # fmt: skip
  # Inspect every parsed segment, not just the first. An escaped brace such as
  # ``{{`` is parsed as its own literal-only segment, so a string like
  # ``"{{x}} {}"`` has no field in its first segment but still has a
  # placeholder. An empty ``fmt`` parses to no segments at all.
  has_placeholders = any(
      field_name is not None
      for _, field_name, _, _ in string.Formatter().parse(fmt)
  )
  return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)
def check_debug_print_format(
fmt: str, *args: jax.ArrayLike
):
n_placeholders = 0
for _, field, spec, conversion in string.Formatter().parse(fmt):
if field is not None:
n_placeholders += 1
if spec or conversion:
raise ValueError(
"The format string should not contain any format specs or conversions"
)
if field:
raise ValueError(
"The format string should not reference arguments by position or name"
)
if len(args) != n_placeholders:
raise TypeError(
f"The format string expects {n_placeholders} "
f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}"
)
@debug_print_p.def_impl
def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):
  """Eager implementation: formats ``args`` into ``fmt`` or prints them after it."""
  if not has_placeholders:
    print(fmt, *args)
  else:
    print(fmt.format(*args))
  return ()
@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool):
  """Requires every printed value to be a scalar; produces no outputs."""
  del fmt, has_placeholders  # Unused.
  for aval in avals:
    if aval.shape:
      raise ValueError("Only scalar values are supported")
  return [], {debug_print_effect}
def debug_print_batching_rule(args, dims, **params):
  """Unrolls the print primitive across the mapped axis.

  Args:
    args: the batched operands.
    dims: for each operand, its batch dimension, or ``batching.not_mapped``.
    **params: the primitive's parameters, forwarded to every unrolled bind.

  Returns:
    The outputs stacked along a new leading axis, and their batch dims (0).
  """
  axis_size = next(x.shape[i] for x, i in zip(args, dims) if i is not None)

  # TODO(sharadmv): implement in terms of rolled loop instead of unrolled.
  def get_arg_at_dim(i, dim, arg):
    if dim is batching.not_mapped:
      # Broadcast unmapped argument
      return arg
    return lax.index_in_dim(arg, i, axis=dim, keepdims=False)

  outs = []
  for i in range(axis_size):
    args_idx = map(functools.partial(get_arg_at_dim, i), dims, args)
    outs.append(debug_print_p.bind(*args_idx, **params))
  outs = [jnp.stack(xs) for xs in zip(*outs)]
  return outs, (0,) * len(outs)


# Batching rules are called as `rule(batched_args, batch_dims, **params)`, so
# the rule is registered directly. Binding `debug_print_p` as a leading
# positional argument (as before) shifted every argument and made
# `vmap(debug_print)` fail with a TypeError.
batching.primitive_batchers[debug_print_p] = debug_print_batching_rule
@functools.partial(mlir.register_lowering, debug_print_p)
def debug_print_lowering_rule(ctx, *args, **params):
  """Lowers `debug_print_p` to a side-effecting host callback running its impl."""
  host_fn = functools.partial(debug_print_p.impl, **params)
  results, _token, _extra = mlir.emit_python_callback(
      ctx,
      host_fn,
      None,
      list(args),
      ctx.avals_in,
      ctx.avals_out,
      has_side_effect=True,
  )
  return results
run_scoped_p = jax_core.Primitive("run_scoped")
run_scoped_p.multiple_results = True


def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
  """Calls ``f`` with freshly allocated references and returns its result.

  Args:
    f: The function to run; it is traced to a jaxpr.
    *types: The types of the function's positional arguments.
    **kw_types: The types of the function's keyword arguments.
  """
  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
  flat_fun, get_out_tree = api_util.flatten_fun(lu.wrap_init(f), in_tree)
  in_avals = [ty.get_aval() for ty in flat_types]
  # Trace `f` to a jaxpr. Its body may have effects (IO) on constvars, i.e.
  # references closed over from the enclosing scope. JAX cannot reason about
  # effects on references that are not operands of an equation, so every
  # const is passed explicitly as an operand of `run_scoped_p`.
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  flat_outs = run_scoped_p.bind(*consts, jaxpr=jaxpr)
  return tree_util.tree_unflatten(get_out_tree(), flat_outs)
@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr):
  del args  # Unused.
  # Effect input indices count the jaxpr's constvars (closed-over refs)
  # first, then its invars (the refs allocated by `run_scoped`). Effects on
  # the allocated refs are local to the scope and must not propagate out.
  num_consts = len(jaxpr.constvars)

  def on_scoped_allocation(eff):
    return (isinstance(eff, effects.JaxprInputEffect)
            and eff.input_index >= num_consts)

  outer_effects = {eff for eff in jaxpr.effects
                   if not on_scoped_allocation(eff)}
  return [v.aval for v in jaxpr.outvars], outer_effects