# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pallas-specific JAX primitives."""

from __future__ import annotations

import enum
import functools
import string
from typing import Any, Callable

import jax
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import types as state_types
from jax._src.state import primitives as sp
from jax.interpreters import mlir
import jax.numpy as jnp

partial = functools.partial
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

program_id_p = jax_core.Primitive("program_id")
batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule

def program_id(axis: int) -> jax.Array:
  """Returns the kernel execution position along the given axis of the grid.

  For example, with a 2D `grid`, in the kernel execution corresponding to the
  grid coordinates `(1, 2)`,
  `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.

  The returned value is an array of shape `()` and dtype `int32`.

  Args:
    axis: the axis of the grid along which to count the program.
  """
  return program_id_p.bind(axis=axis)
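
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# a kernel body that uses `program_id` to pick the grid position it operates
# on. The ref names and shapes are placeholder assumptions.
def _example_program_id_kernel(x_ref, o_ref):
  i = program_id(axis=0)  # Scalar int32 position along grid axis 0.
  o_ref[...] = x_ref[...] * i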

def program_id_bind_with_trace(trace, _, params):
  axis = params.pop("axis")
  grid_env = pallas_core.current_grid_env()
  if grid_env:
    return grid_env[axis].index
  frame = pallas_core.axis_frame()
  # Query the size of the axis to make sure it's a valid axis (and error
  # otherwise).
  _ = frame.size(axis)
  return jax_core.Primitive.bind_with_trace(program_id_p, trace, (), dict(axis=axis))
# TODO(dougalm): figure out how to put the grid_env context on the relevant trace
program_id_p.def_bind_with_trace(program_id_bind_with_trace)

@program_id_p.def_abstract_eval
def _program_id_abstract_eval(**_):
  return jax_core.ShapedArray((), jnp.int32)

num_programs_p = jax_core.Primitive("num_programs")

def num_programs(axis: int) -> int | jax.Array:
  """Returns the size of the grid along the given axis."""
  return num_programs_p.bind(axis=axis)
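
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# `num_programs` is typically used for grid-size-dependent logic, e.g.
# averaging a value accumulated across all grid steps along axis 0.
def _example_num_programs_kernel(o_ref):
  o_ref[...] = o_ref[...] / num_programs(axis=0)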

def _num_programs_bind_with_trace(trace, _, params):
  axis = params.pop("axis")
  # We might be using a local grid env
  grid_env = pallas_core.current_grid_env()
  if grid_env:
    return grid_env[axis].size
  # Otherwise, we look up the size of the grid in the axis env
  frame = pallas_core.axis_frame()
  size = frame.size(axis)
  if size is pallas_core.dynamic_grid_dim:
    return jax_core.Primitive.bind_with_trace(num_programs_p, trace, (), dict(axis=axis))
  return size
num_programs_p.def_bind_with_trace(_num_programs_bind_with_trace)

@num_programs_p.def_abstract_eval
def _num_programs_abstract_eval(**_):
  return jax_core.ShapedArray((), jnp.int32)

class AtomicOpType(enum.Enum):
  XCHG = "xchg"
  ADD = "add"
  MAX = "max"
  MIN = "min"
  AND = "and"
  OR = "or"
  XOR = "xor"

atomic_rmw_p = jax_core.Primitive("atomic_rmw")


def _atomic_rmw_discharge_rule(
    in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType
):
  del out_avals  # Unused.
  ref, indexers, val, mask = args_tree.unflatten(args_flat)
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer is supported.")
  idx = indexers[0]

  if mask is not None:
    raise NotImplementedError

  if atomic_type == AtomicOpType.ADD:
    monoid = lambda x, y: x + y
  elif atomic_type == AtomicOpType.MAX:
    monoid = jnp.maximum
  elif atomic_type == AtomicOpType.MIN:
    monoid = jnp.minimum
  else:
    raise NotImplementedError(atomic_type)

  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
    val = val[val_indexer]
    val = monoid(val, out_ones)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    x_new = ref.at[idx.indices].set(monoid(out, val))
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out


state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)


@atomic_rmw_p.def_effectful_abstract_eval
def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
  ref, _, _, _ = args_tree.unflatten(avals_flat)
  if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
    raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.")
  if ref.dtype in {
      jnp.dtype("bool"),
      jnp.dtype("int8"),
      jnp.dtype("int16"),
      jnp.bfloat16,
  }:
    raise ValueError(
        f"`atomic_{atomic_type.value}` does not support {ref.dtype}."
    )
  return _swap_abstract_eval(*avals_flat, args_tree=args_tree)


def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
                atomic_type: AtomicOpType):
  x_ref, transforms = sp.get_ref_and_transforms(
      x_ref_or_view, idx, "atomic_rmw"
  )
  args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask))
  return atomic_rmw_p.bind(
      *args_flat, args_tree=args_tree, atomic_type=atomic_type
  )

def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically exchanges the given value with the value at the given index.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the update.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG
  )


def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] += val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD
  )
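
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# each grid step accumulates a partial sum into element 0 of a shared output
# ref; the return value (the pre-update contents) is discarded here.
def _example_atomic_add_kernel(x_ref, o_ref):
  _ = atomic_add(o_ref, (0,), jnp.sum(x_ref[...]))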


def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX
  )


def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN
  )


def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] &= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND
  )


def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] |= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR
  )


def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] ^= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR
  )

atomic_cas_p = jax_core.Primitive("atomic_cas")

@atomic_cas_p.def_effectful_abstract_eval
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
  if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape:
    raise ValueError("cmp and val must have identical dtypes and shapes")
  if ref_aval.shape:
    raise ValueError("ref must be scalar.")
  if cmp_aval.shape:
    raise ValueError("cmp must be scalar.")
  if val_aval.shape:
    raise ValueError("val must be scalar.")
  return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)}


def atomic_cas(ref, cmp, val):
  """Performs an atomic compare-and-swap of the value in the ref with the
  given value.

  Args:
    ref: The ref to operate on.
    cmp: The expected value to compare against.
    val: The value to swap in.

  Returns:
    The value stored in the ref prior to the atomic operation.
  """
  return atomic_cas_p.bind(ref, cmp, val)
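
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# using `atomic_cas` on a scalar i32 ref as a simple lock acquire. A returned
# value of 0 means the ref held 0 (free) before the swap, i.e. the caller now
# owns the lock.
def _example_atomic_cas_try_acquire(lock_ref):
  return atomic_cas(lock_ref, 0, 1) == 0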

@state_discharge.register_discharge_rule(atomic_cas_p)
def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
  del in_avals, out_avals
  new_val = jnp.where(ref == cmp, val, ref)
  return (new_val, None, None), ref

max_contiguous_p = jax_core.Primitive("max_contiguous")

max_contiguous_p.def_impl(lambda x, **_: x)
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])

def max_contiguous(x, values):
  if not isinstance(values, list):
    values = [values]
  return max_contiguous_p.bind(x, values=values)

@max_contiguous_p.def_abstract_eval
def _max_contiguous_abstract_eval(aval, **_):
  return aval

multiple_of_p = jax_core.Primitive("multiple_of")

multiple_of_p.def_impl(lambda x, **_: x)
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])

def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array:
  if not isinstance(values, list):
    values = [values]
  return multiple_of_p.bind(x, values=values)
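
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# `max_contiguous` and `multiple_of` are compiler hints and return their input
# unchanged; here a dynamically computed offset is asserted to be 128-aligned.
def _example_multiple_of_hint(block_index: jax.Array) -> jax.Array:
  return multiple_of(block_index * 128, 128)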

@multiple_of_p.def_abstract_eval
def _multiple_of_abstract_eval(aval, **_):
  return aval

load_p = jax_core.Primitive('masked_load')


@load_p.def_effectful_abstract_eval
def _load_abstract_eval(*avals_flat, args_tree, **_):
  ref, indexers, _, _ = args_tree.unflatten(avals_flat)
  return (
      jax_core.ShapedArray(indexers[-1].get_indexer_shape(), ref.dtype),
      {state.ReadEffect(0)},
  )


def _load_pp_rule(eqn, context, settings):
  # Pretty prints `a = load x i` as `a <- x[i]`
  y, = eqn.outvars
  x, indexers, mask, other = tree_util.tree_unflatten(eqn.params["args_tree"],
                                                      eqn.invars)
  # TODO(sharadmv): pretty print mask and other
  lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  result = [
      lhs,
      pp.text(' <- '),
      sp.pp_ref_transforms(context, x, indexers)
  ]
  if mask is not None:
    result += [
        pp.text(" "),
        pp.text("mask="),
        pp.text(jax_core.pp_var(mask, context)),
    ]
  if other is not None:
    result += [
        pp.text(" "),
        pp.text("other="),
        pp.text(jax_core.pp_var(other, context)),
    ]
  return pp.concat(result)
jax_core.pp_eqn_rules[load_p] = _load_pp_rule


def _load_jvp(primals, tangents, args_tree, **params):
  ref_primal, indexers, mask, other_primal = args_tree.unflatten(primals)
  ref_tangent, _, _, other_tangent = args_tree.unflatten(tangents)
  if other_tangent is not None:
    other_tangent = ad_util.instantiate(other_tangent)
  return (
      load_p.bind(
          *tree_util.tree_leaves((ref_primal, indexers, mask, other_primal)),
          args_tree=args_tree,
          **params,
      ),
      load_p.bind(
          *tree_util.tree_leaves((ref_tangent, indexers, mask, other_tangent)),
          args_tree=args_tree,
          **params,
      ),
  )


ad.primitive_jvps[load_p] = _load_jvp

def uninitialized_value(shape, dtype):
  if jnp.issubdtype(dtype, jnp.floating):
    return jnp.full(shape, jnp.nan, dtype)
  # Note: Currently semaphore is i16[], meaning this case needs to be
  # handled before the general case for integers.
  # TODO(justinfu): Handle semaphores with a custom extended dtype.
  elif jnp.issubdtype(dtype, pallas_core.SEMAPHORE_INTERPRET_DTYPE):
    return jnp.full(shape, 0, dtype)
  elif jnp.issubdtype(dtype, jnp.integer):
    return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
  elif jnp.issubdtype(dtype, jnp.bool):
    return jnp.full(shape, False, dtype)
  elif jnp.issubdtype(dtype, pallas_core.semaphore_dtype):
    return jnp.full(shape, 0, dtype)
  raise NotImplementedError(dtype)

def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
                                                 slice_sizes, unpad=False):
  """
  DynamicSlice and DynamicUpdateSlice adjust the start index in cases where the
  requested slice overruns the bounds of the array. This pads the array with
  uninitialised values such that the requested slice will never overrun.

  For example, if arr is [1.,2.,3.,4.] and a slice of size 4 starting at
  index 2 is requested, then the result will be [3.,4.,NaN,NaN] after padding,
  rather than [1.,2.,3.,4.] from the unpadded array.

  unpad=True performs the inverse operation.
  """

  padding_config = tuple((0, slice_size, 0) for slice_size in slice_sizes)
  if unpad:
    padding_config = tuple((-low, -high, -interior)
                           for (low, high, interior) in padding_config)
  padding_value = uninitialized_value(shape=(), dtype=value.dtype)
  value = lax.pad(value,
                  padding_config=padding_config,
                  padding_value=padding_value)
  return value
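
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# after padding a length-4 float array for size-4 slices, a slice starting at
# index 2 yields [3., 4., nan, nan] instead of being shifted back in bounds.
def _example_oob_slice_after_padding():
  arr = jnp.array([1., 2., 3., 4.])
  padded = _pad_values_to_avoid_dynamic_slice_oob_shift(arr, slice_sizes=(4,))
  return lax.dynamic_slice(padded, (2,), slice_sizes=(4,))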

_unpad_values_to_avoid_dynamic_slice_oob_shift = partial(
    _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True)


@state_discharge.register_discharge_rule(load_p)
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  del out_avals  # Unused.
  ref, indexers, mask, other = args_tree.unflatten(args_flat)
  # TODO(sharadmv): add support for multiple indexers
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer supported in discharge rule.")
  idx = indexers[0]
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    # TODO(ayx): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # Fixes an inconsistency with lax.dynamic_slice where if the slice goes out
    # of bounds, it will instead move the start_index backwards so the slice
    # will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
    out_ones = lax.dynamic_slice(
        ref,
        [jnp.astype(s, idx_dtype) for s in slice_starts],
        slice_sizes=slice_sizes,
    )
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
  else:
    raise NotImplementedError
  if mask is not None and other is not None:
    out = jnp.where(mask, out, other)
  return (None,) * len(in_avals), out


swap_p = jax_core.Primitive('masked_swap')


@swap_p.def_effectful_abstract_eval
def _swap_abstract_eval(*avals_flat, args_tree, **_):
  ref, indexers, val, _ = args_tree.unflatten(avals_flat)
  expected_output_shape = indexers[-1].get_indexer_shape()
  if expected_output_shape != val.shape:
    raise ValueError(
        f"Invalid shape for `swap`. Ref shape: {ref.shape}. "
        f"Value shape: {val.shape}. Indices: {indexers}. "
    )
  if ref.dtype != val.dtype:
    raise ValueError(
        f"Invalid dtype for `swap`. Ref dtype: {ref.dtype}. "
        f"Value dtype: {val.dtype}. "
    )
  return (
      jax_core.ShapedArray(expected_output_shape, ref.dtype),
      {state.WriteEffect(0)},
  )


def _swap_pp_rule(eqn, context, settings):
  # Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
  # or:
  # Pretty prints `_ = swap x v i` as `x[i] <- v`
  y, = eqn.outvars
  x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars)
  x_i = sp.pp_ref_transforms(context, x, indexers)
  if isinstance(y, jax_core.DropVar):
    return pp.concat([
        x_i,
        pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
  y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  result = [
      y,
      pp.text(", "),
      x_i,
      pp.text(" <- "),
      x_i,
      pp.text(", "),
      pp.text(jax_core.pp_var(val, context)),
  ]
  if mask is not None:
    result += [
        pp.text(" "),
        pp.text("mask="),
        pp.text(jax_core.pp_var(mask, context)),
    ]
  return pp.concat(result)
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule


def _swap_jvp(primals, tangents, *, args_tree, **params):
  ref_primal, indexers, val_primal, mask = args_tree.unflatten(primals)
  ref_tangent, _, val_tangent, _ = args_tree.unflatten(tangents)
  val_tangent = ad_util.instantiate(val_tangent)
  return (
      swap_p.bind(
          *tree_util.tree_leaves((ref_primal, indexers, val_primal, mask)),
          args_tree=args_tree,
          **params,
      ),
      swap_p.bind(
          *tree_util.tree_leaves((ref_tangent, indexers, val_tangent, mask)),
          args_tree=args_tree,
          **params,
      ),
  )


ad.primitive_jvps[swap_p] = _swap_jvp


@state_discharge.register_discharge_rule(swap_p)
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  del out_avals  # Unused.
  ref, indexers, val, mask = args_tree.unflatten(args_flat)
  if len(indexers) > 1:
    raise NotImplementedError("Only one indexer supported in discharge rule.")
  idx = indexers[0]
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    # TODO(ayx): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    scalar_dims = [
        i
        for i, s in enumerate(indices)
        if not isinstance(s, Slice) and not s.shape
    ]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # Fixes an inconsistency with lax.dynamic_update_slice where if the slice
    # goes out of bounds, it will instead move the start_index backwards so the
    # slice will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    out = jnp.squeeze(out, scalar_dims)
    if mask is not None:
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    val = jnp.expand_dims(val, scalar_dims)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes)
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    if mask is not None:
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    x_new = ref.at[idx.indices].set(val)
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out


def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
         eviction_policy=None, volatile=False) -> jax.Array:
  """Returns an array loaded from the given index.

  If neither ``mask`` nor ``other`` is specified, this function has the same
  semantics as ``x_ref_or_view[idx]`` in JAX.

  Args:
    x_ref_or_view: The ref to load from.
    idx: The indexer to use.
    mask: An optional boolean mask specifying which indices to load.
      If mask is ``False`` and ``other`` is not given, no assumptions can
      be made about the value in the resulting array.
    other: An optional value to use for indices where mask is ``False``.
    cache_modifier: TO BE DOCUMENTED.
    eviction_policy: TO BE DOCUMENTED.
    volatile: TO BE DOCUMENTED.
  """
  x_ref, transforms = sp.get_ref_and_transforms(x_ref_or_view, idx, "load")
  args_flat, args_tree = tree_util.tree_flatten(
      (x_ref, transforms, mask, other)
  )
  return load_p.bind(
      *args_flat,
      args_tree=args_tree,
      cache_modifier=cache_modifier,
      eviction_policy=eviction_policy,
      is_volatile=volatile,
  )
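
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# a masked load over a 1D block that reads the first `n` lanes and fills the
# rest with zeros, the Pallas analogue of a bounds-checked read.
def _example_masked_load_kernel(x_ref, o_ref, n):
  lanes = jnp.arange(x_ref.shape[0])
  o_ref[...] = load(x_ref, (lanes,), mask=lanes < n, other=0.0)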

def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
         _function_name="swap") -> jax.Array:
  """Swaps the value at the given index and returns the old value.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.

  Returns:
    The value stored in the ref prior to the swap.
  """
  x_ref, transforms = sp.get_ref_and_transforms(
      x_ref_or_view, idx, _function_name
  )
  args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask))
  return swap_p.bind(
      *args_flat, args_tree=args_tree, eviction_policy=eviction_policy
  )

def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
  """Stores a value at the given index.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
  """
  _ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
           _function_name="store")
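
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# the masked counterpart of `o_ref[...] = val`, writing only the first `n`
# lanes of a 1D output block and leaving the rest untouched.
def _example_masked_store_kernel(o_ref, n):
  lanes = jnp.arange(o_ref.shape[0])
  store(o_ref, (lanes,), jnp.zeros(o_ref.shape, o_ref.dtype), mask=lanes < n)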

def dot(a, b, trans_a: bool = False, trans_b: bool = False,
        allow_tf32: bool | None = None, precision=None):
  if (a.ndim != 2) or (b.ndim != 2):
    raise ValueError("`a` and `b` must be 2D arrays.")
  lhs_contract_dim = 0 if trans_a else 1
  rhs_contract_dim = 0 if not trans_b else 1
  if allow_tf32 is not None:
    if precision is not None:
      raise ValueError("Only one of allow_tf32 and precision can be specified")
    precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
  return jax.lax.dot_general(
      a,
      b,
      dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
      precision=precision,
      preferred_element_type=jnp.float32,
  )
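
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# a block matmul kernel body; `dot` contracts the shared dimension of the two
# blocks and accumulates in float32 regardless of the input dtype.
def _example_dot_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = dot(x_ref[...], y_ref[...])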


class PrintEffect(effects.Effect):
  __str__ = lambda self: "Print"


debug_print_effect = PrintEffect()

# TODO(slebedev): Consider making the effect ordered.
effects.lowerable_effects.add_type(PrintEffect)
effects.control_flow_allowed_effects.add_type(PrintEffect)
effects.remat_allowed_effects.add_type(PrintEffect)
effects.custom_derivatives_allowed_effects.add_type(PrintEffect)


debug_print_p = jax_core.Primitive("debug_print")
debug_print_p.multiple_results = True


def debug_print(fmt: str, *args: jax.typing.ArrayLike):
  """Prints values from inside a Pallas kernel.

  Args:
    fmt: A format string to be included in the output. The restrictions on the
      format string depend on the backend:

      * On GPU, when using Triton, ``fmt`` must not contain any placeholders
        (``{...}``), since it is always printed before any of the values.
      * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
        contain a placeholder for each value to be printed. Format specs and
        conversions are not supported. All values must be scalars.
      * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders,
        all values must be 32-bit integers. If there are no placeholders, the
        values are printed after the format string.
      * On TPU, if the input is a single vector, the vector is printed after
        the format string. The format string must end with a single placeholder
        ``{}``.
    *args: The values to print.
  """  # fmt: skip
  has_placeholders = False
  if fmt:
    _, field_name, *_ = next(iter(string.Formatter().parse(fmt)))
    has_placeholders = field_name is not None
  return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)
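
# Illustrative sketch (hypothetical, for documentation only; never invoked):
# printing the current grid position from a kernel body. Note the placeholder
# rules in the docstring above; on TPU, placeholder arguments must be 32-bit
# integers.
def _example_debug_print_kernel(x_ref, o_ref):
  debug_print("program_id = {}", program_id(axis=0))
  o_ref[...] = x_ref[...]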


def check_debug_print_format(
    fmt: str, *args: jax.typing.ArrayLike
):
  n_placeholders = 0
  for _, field, spec, conversion in string.Formatter().parse(fmt):
    if field is not None:
      n_placeholders += 1
    if spec or conversion:
      raise ValueError(
          "The format string should not contain any format specs or conversions"
      )
    if field:
      raise ValueError(
          "The format string should not reference arguments by position or name"
      )

  if len(args) != n_placeholders:
    raise TypeError(
        f"The format string expects {n_placeholders} "
        f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}"
    )


@debug_print_p.def_impl
def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):
  if has_placeholders:
    print(fmt.format(*args))
  else:
    print(fmt, *args)
  return ()


@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool):
  del avals, fmt, has_placeholders  # Unused.
  return [], {debug_print_effect}


def debug_print_batching_rule(args, dims, **params):
  """Unrolls the print primitive across the mapped axis."""
  axis_size = next(x.shape[i] for x, i in zip(args, dims) if i is not None)

  # TODO(sharadmv): implement in terms of rolled loop instead of unrolled.
  def get_arg_at_dim(i, dim, arg):
    if dim is batching.not_mapped:
      # Broadcast unmapped argument
      return arg
    return lax.index_in_dim(arg, i, axis=dim, keepdims=False)

  outs = []
  for i in range(axis_size):
    args_idx = map(functools.partial(get_arg_at_dim, i), dims, args)
    outs.append(debug_print_p.bind(*args_idx, **params))
  outs = [jnp.stack(xs) for xs in zip(*outs)]
  return outs, (0,) * len(outs)


batching.primitive_batchers[debug_print_p] = functools.partial(
    debug_print_batching_rule, debug_print_p
)


@functools.partial(mlir.register_lowering, debug_print_p)
def debug_print_lowering_rule(ctx, *args, **params):
  result, _, _ = mlir.emit_python_callback(
      ctx,
      functools.partial(debug_print_p.impl, **params),
      None,
      list(args),
      ctx.avals_in,
      ctx.avals_out,
      has_side_effect=True,
  )
  return result


# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
# inferred by the compiler.
@lu.transformation2
def wrap_with_transforms(f, transforms, *args):
  new_args = tuple(
      state_types.TransformedRef(a, t) if t else a
      for a, t in zip(args, transforms)
  )
  return f(*new_args)


run_scoped_p = jax_core.Primitive("run_scoped")
run_scoped_p.multiple_results = True


def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
  """Calls the function with allocated references and returns the result.

  The positional and keyword arguments describe which reference types
  to allocate for each argument. Each backend has its own set of reference
  types in addition to :class:`jax.experimental.pallas.MemoryRef`.
  """
  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
  flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
  # We allow ref avals to be transformed references.
  ref_avals = [t.get_ref_aval() for t in flat_types]
  avals = [
      t.ref if isinstance(t, state_types.TransformedRef) else t
      for t in ref_avals
  ]
  ref_transforms = tuple(
      t.transforms if isinstance(t, state_types.TransformedRef) else ()
      for t in ref_avals
  )
  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
  # Turn the function into a jaxpr. The body of run_scoped may have
  # effects (IO) on constvars (i.e. variables inherited from the
  # parent scope). Jax can't reason about effects to references that
  # are not in the invars of an operation so we just put them all
  # there.
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
  out = run_scoped_p.bind(*consts, jaxpr=jaxpr)
  return tree_util.tree_unflatten(out_tree_thunk(), out)
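
# Illustrative sketch (hypothetical, for documentation only; kept as a comment
# because the concrete reference types are backend-specific, e.g. `pltpu.VMEM`
# on TPU):
#
#   def body(scratch_ref):
#     scratch_ref[...] = jnp.zeros_like(scratch_ref)
#     ...  # fill and use the scratch buffer
#     return scratch_ref[0, 0]
#
#   result = run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))
#
# `run_scoped` allocates the requested references, calls `body` with them, and
# releases them when the call returns.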


@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr):
  del args
  # jaxpr will have effects for its inputs (Refs that are allocated) and for
  # constvars (closed over Refs). The effects for the allocated Refs are local
  # to the jaxpr and shouldn't propagate out.
  nonlocal_effects = {
      eff
      for eff in jaxpr.effects
      if not (
          isinstance(eff, effects.JaxprInputEffect)
          and eff.input_index >= len(jaxpr.constvars)
      )
  }
  return [v.aval for v in jaxpr.outvars], nonlocal_effects


def _run_scoped_discharge_rule(
    should_discharge,
    in_avals,
    out_avals,
    *args_flat,
    jaxpr,
    **_):
  del out_avals
  num_consts = len(args_flat)
  jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
  num_return_values = len(jaxpr_noconst.outvars)
  should_discharge = should_discharge + [
      isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
  ]
  discharged_body, new_consts = state_discharge.discharge_state(
      jaxpr_noconst, [], should_discharge=should_discharge)
  if new_consts:
    raise NotImplementedError(
        "Cannot handle new consts created by state discharge.")
  # Create inputs filled with uninitialized values to the body.
  body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
  init_vals = [uninitialized_value(
      aval.shape, aval.dtype) for aval in body_avals]
  init_vals_with_consts = args_flat + tuple(init_vals)
  out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)
  # Order of outputs:
  # (1) return values, (2) closed refs, (3) scoped refs.
  return_values = out[:num_return_values]
  ref_outputs = out[num_return_values:]
  # We update all ref values with their updated values from the discharged
  # body. For other values we leave them in place.
  updates = [
      ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
      else None for aval in in_avals]
  assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
  return updates, return_values


state_discharge.register_partial_discharge_rule(run_scoped_p)(
    _run_scoped_discharge_rule)


@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
  # This lowering rule gets triggered when run_scoped is not discharged.
  # In this case there are no stateful effects to handle.
  should_discharge = [
      isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
  ]

  def _lower_fun(*lower_fun_args):
    updates, out = _run_scoped_discharge_rule(
        should_discharge,
        [], [], *lower_fun_args,
        jaxpr=jaxpr)
    assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
    return out
  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)