Documented a few more Pallas APIs and added them to the API docs

Sergei Lebedev 2024-07-19 14:10:37 +01:00
parent d8f435094d
commit 4fa93cff35
2 changed files with 192 additions and 42 deletions


@ -10,6 +10,7 @@ Classes
:toctree: _autosummary
BlockSpec
Slice
Functions
---------
@ -21,3 +22,16 @@ Functions
program_id
num_programs
load
store
swap
atomic_and
atomic_add
atomic_cas
atomic_max
atomic_min
atomic_or
atomic_xchg
atomic_xor
debug_print


@ -60,20 +60,20 @@ def program_id(axis: int) -> jax.Array:
"""
return program_id_p.bind(axis=axis)
@program_id_p.def_custom_bind
def program_id_bind(*, axis: int):
grid_env = pallas_core.current_grid_env()
if grid_env:
return grid_env[axis].index
frame = pallas_core.axis_frame()
# Query the size of the axis to make sure it's a valid axis (and error
# otherwise).
_ = frame.size(axis)
return jax_core.Primitive.bind(program_id_p, axis=axis)
program_id_p.def_custom_bind(program_id_bind)
@program_id_p.def_abstract_eval
def _program_id_abstract_eval(**_):
return jax_core.ShapedArray((), jnp.int32)
program_id_p.def_abstract_eval(_program_id_abstract_eval)
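# Illustrative sketch (hypothetical kernel): inside a Pallas kernel launched
# over ``grid=(8,)``, ``program_id(0)`` gives this program instance's
# coordinate along the first grid axis:
#
#   from jax.experimental import pallas as pl
#
#   def iota_kernel(o_ref):
#     i = pl.program_id(0)  # 0, 1, ..., 7 across the grid
#     o_ref[i] = i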
num_programs_p = jax_core.Primitive("num_programs")
@ -154,6 +154,7 @@ def _atomic_rmw_discharge_rule(
state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)
@atomic_rmw_p.def_effectful_abstract_eval
def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
ref, _, _, _ = args_tree.unflatten(avals_flat)
if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
@ -170,41 +171,152 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
return _swap_abstract_eval(*avals_flat, args_tree=args_tree)
atomic_rmw_p.def_effectful_abstract_eval(_atomic_abstract_eval)
def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
atomic_type: AtomicOpType):
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw")
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
return atomic_rmw_p.bind(
*args_flat, args_tree=args_tree, atomic_type=atomic_type
)
def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically exchanges the given value with the value at the given index.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the update.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG
)
def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] += val``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD
)
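# Illustrative sketch (assumes a backend with atomic support, e.g. Triton on
# GPU): several programs can accumulate into the same location with
# ``atomic_add``; the returned old value may simply be ignored. The kernel
# below is hypothetical:
#
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   def accumulate_kernel(x_ref, o_ref):
#     block_sum = jnp.sum(x_ref[...])
#     pl.atomic_add(o_ref, (0,), block_sum)  # o_ref[0] += block_sum, atomically
#
# The remaining ``atomic_*`` wrappers in this file share the same signature and
# differ only in the combining operation.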
def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX
)
def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN
)
def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] &= val``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND
)
def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] |= val``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR
)
def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None):
"""Atomically computes ``x_ref_or_view[idx] ^= val``.
Args:
x_ref_or_view: The ref to operate on.
idx: The indexer to use.
mask: TO BE DOCUMENTED.
Returns:
The value at the given index prior to the atomic operation.
"""
return _atomic_rmw(
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR
)
atomic_cas_p = jax_core.Primitive("atomic_cas")
@atomic_cas_p.def_effectful_abstract_eval
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape:
raise ValueError("cmp and val must have identical dtypes and shapes")
if ref_aval.shape:
raise ValueError("ref must be scalar.")
if cmp_aval.shape:
raise ValueError("cmp must be scalar.")
if val_aval.shape:
raise ValueError("val must be scalar.")
return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)}
atomic_cas_p.def_effectful_abstract_eval(_atomic_cas_abstract_eval)
def atomic_cas(ref, cmp, val):
"""Performs an atomic compare-and-swap of the value in the ref with the
given value.
Args:
ref: The ref to operate on.
cmp: The expected value to compare against.
val: The value to swap in.
Returns:
The value in the ref prior to the atomic operation.
"""
return atomic_cas_p.bind(ref, cmp, val)
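# Illustrative sketch: ``atomic_cas`` can serve as a scalar lock. Assuming a
# hypothetical scalar int32 ref ``lock_ref`` initialized to 0:
#
#   from jax.experimental import pallas as pl
#
#   # If lock_ref holds 0, replace it with 1; the return value is the previous
#   # contents, so observing 0 means this program acquired the lock.
#   prev = pl.atomic_cas(lock_ref, 0, 1)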
@state_discharge.register_discharge_rule(atomic_cas_p)
@ -223,9 +335,9 @@ def max_contiguous(x, values):
values = [values]
return max_contiguous_p.bind(x, values=values)
@max_contiguous_p.def_abstract_eval
def _max_contiguous_abstract_eval(aval, **_):
return aval
max_contiguous_p.def_abstract_eval(_max_contiguous_abstract_eval)
multiple_of_p = jax_core.Primitive("multiple_of")
@ -237,13 +349,14 @@ def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array:
values = [values]
return multiple_of_p.bind(x, values=values)
@multiple_of_p.def_abstract_eval
def _multiple_of_abstract_eval(aval, **_):
return aval
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)
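# Illustrative sketch: ``multiple_of`` and ``max_contiguous`` only attach
# alignment/contiguity hints and return their input unchanged. A hypothetical
# pattern when building index vectors for a block:
#
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   offs = pl.program_id(0) * 128 + jnp.arange(128)
#   # Hint that the offsets start at a multiple of 128 and are contiguous in
#   # runs of 128, which can help the compiler emit vectorized accesses.
#   offs = pl.max_contiguous(pl.multiple_of(offs, 128), 128)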
load_p = jax_core.Primitive('masked_load')
@load_p.def_effectful_abstract_eval
def _load_abstract_eval(*avals_flat, args_tree, **_):
ref, indexers, _, _ = args_tree.unflatten(avals_flat)
return (
@ -252,8 +365,6 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):
)
load_p.def_effectful_abstract_eval(_load_abstract_eval)
def _load_pp_rule(eqn, context, settings):
# Pretty prints `a = load x i` as `x[i] <- a`
y, = eqn.outvars
@ -339,6 +450,8 @@ def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
_unpad_values_to_avoid_dynamic_slice_oob_shift = partial(
_pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True)
@state_discharge.register_discharge_rule(load_p)
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
del out_avals # Unused.
ref, indexers, mask, other = args_tree.unflatten(args_flat)
@ -371,11 +484,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
return (None,) * len(in_avals), out
state_discharge.register_discharge_rule(load_p)(_load_discharge_rule)
swap_p = jax_core.Primitive('masked_swap')
@swap_p.def_effectful_abstract_eval
def _swap_abstract_eval(*avals_flat, args_tree, **_):
ref, indexers, val, _ = args_tree.unflatten(avals_flat)
expected_output_shape = indexers[-1].get_indexer_shape()
@ -395,8 +507,6 @@ def _swap_abstract_eval(*avals_flat, args_tree, **_):
)
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _swap_pp_rule(eqn, context, settings):
# Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
# or:
@ -449,6 +559,7 @@ def _swap_jvp(primals, tangents, *, args_tree, **params):
ad.primitive_jvps[swap_p] = _swap_jvp
@state_discharge.register_discharge_rule(swap_p)
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
del out_avals # Unused.
ref, indexers, val, mask = args_tree.unflatten(args_flat)
@ -493,11 +604,24 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
return (x_new,) + (None,) * (len(in_avals) - 1), out
state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
eviction_policy=None, volatile=False) -> jax.Array:
"""Returns an array loaded from the given index.
If neither ``mask`` nor ``other`` is specified, this function has the same
semantics as ``x_ref_or_view[idx]`` in JAX.
Args:
x_ref_or_view: The ref to load from.
idx: The indexer to use.
mask: An optional boolean mask specifying which indices to load.
Where the mask is ``False`` and ``other`` is not given, no assumptions can
be made about the corresponding values in the resulting array.
other: An optional value to use for indices where mask is ``False``.
cache_modifier: TO BE DOCUMENTED.
eviction_policy: TO BE DOCUMENTED.
volatile: TO BE DOCUMENTED.
"""
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
return load_p.bind(
@ -509,7 +633,14 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
)
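# Illustrative sketch (hypothetical kernel): a masked load that reads a block
# of 128 elements while guarding against running past the end of the ref,
# substituting 0.0 for the out-of-bounds lanes. ``total_size`` is assumed to
# be known to the kernel:
#
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   offs = pl.program_id(0) * 128 + jnp.arange(128)
#   x = pl.load(x_ref, (offs,), mask=offs < total_size, other=0.0)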
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
_function_name="swap") -> Any:
_function_name="swap") -> jax.Array:
"""Swaps the value at the given index and returns the old value.
See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
Returns:
The value stored in the ref prior to the swap.
"""
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
return swap_p.bind(
@ -517,6 +648,10 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
)
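# Illustrative sketch: ``swap`` writes new values and returns the previous
# contents in one step. Hypothetical usage on the first 128 elements of a ref:
#
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   old_vals = pl.swap(x_ref, (pl.ds(0, 128),), jnp.zeros((128,), jnp.float32))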
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
"""Stores a value at the given index.
See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
"""
_ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
_function_name="store")
@ -562,14 +697,15 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
Args:
fmt: A format string to be included in the output. The restrictions on the
format string depend on the backend:
* On GPU, when using Triton, ``fmt`` must not contain any placeholders
(``{...}``), since it is always printed before any of the values.
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
contain a placeholder for each value to be printed. Format specs and
conversions are not supported.
* On TPU, if ``fmt`` contains placeholders, all values must be 32-bit
integers. If there are no placeholders, the values are printed after
the format string.
*args: The scalar values to print.
""" # fmt: skip
has_placeholders = False
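# Illustrative sketch: printing a scalar from inside a kernel. On TPU and the
# Mosaic GPU backend the format string carries one placeholder per value; with
# Triton the values are printed after a placeholder-free format string:
#
#   from jax.experimental import pallas as pl
#
#   pl.debug_print("block index: {}", pl.program_id(0))  # TPU / Mosaic GPU
#   pl.debug_print("block index:", pl.program_id(0))     # Triton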