mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Documented a few more Pallas APIs and added them to the API docs
This commit is contained in:
parent
d8f435094d
commit
4fa93cff35
@ -10,6 +10,7 @@ Classes
|
||||
:toctree: _autosummary
|
||||
|
||||
BlockSpec
|
||||
Slice
|
||||
|
||||
Functions
|
||||
---------
|
||||
@ -21,3 +22,16 @@ Functions
|
||||
program_id
|
||||
num_programs
|
||||
|
||||
load
|
||||
store
|
||||
swap
|
||||
|
||||
atomic_and
|
||||
atomic_add
|
||||
atomic_cas
|
||||
atomic_max
|
||||
atomic_min
|
||||
atomic_or
|
||||
atomic_xchg
|
||||
|
||||
debug_print
|
@ -60,20 +60,20 @@ def program_id(axis: int) -> jax.Array:
|
||||
"""
|
||||
return program_id_p.bind(axis=axis)
|
||||
|
||||
@program_id_p.def_custom_bind
|
||||
def program_id_bind(*, axis: int):
|
||||
grid_env = pallas_core.current_grid_env()
|
||||
if grid_env:
|
||||
return grid_env[axis].index
|
||||
frame = pallas_core.axis_frame()
|
||||
# Query the size of the axis to make sure its a valid axis (and error
|
||||
# Query the size of the axis to make sure it's a valid axis (and error
|
||||
# otherwise).
|
||||
_ = frame.size(axis)
|
||||
return jax_core.Primitive.bind(program_id_p, axis=axis)
|
||||
program_id_p.def_custom_bind(program_id_bind)
|
||||
|
||||
@program_id_p.def_abstract_eval
|
||||
def _program_id_abstract_eval(**_):
|
||||
return jax_core.ShapedArray((), jnp.int32)
|
||||
program_id_p.def_abstract_eval(_program_id_abstract_eval)
|
||||
|
||||
num_programs_p = jax_core.Primitive("num_programs")
|
||||
|
||||
@ -154,6 +154,7 @@ def _atomic_rmw_discharge_rule(
|
||||
state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)
|
||||
|
||||
|
||||
@atomic_rmw_p.def_effectful_abstract_eval
|
||||
def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
|
||||
ref, _, _, _ = args_tree.unflatten(avals_flat)
|
||||
if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
|
||||
@ -170,41 +171,152 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
|
||||
return _swap_abstract_eval(*avals_flat, args_tree=args_tree)
|
||||
|
||||
|
||||
atomic_rmw_p.def_effectful_abstract_eval(_atomic_abstract_eval)
|
||||
|
||||
def atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
|
||||
atomic_type: AtomicOpType):
|
||||
def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
|
||||
atomic_type: AtomicOpType):
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw")
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
|
||||
return atomic_rmw_p.bind(
|
||||
*args_flat, args_tree=args_tree, atomic_type=atomic_type
|
||||
)
|
||||
|
||||
atomic_xchg = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XCHG)
|
||||
atomic_add = functools.partial(atomic_rmw, atomic_type=AtomicOpType.ADD)
|
||||
atomic_max = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MAX)
|
||||
atomic_min = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MIN)
|
||||
atomic_and = functools.partial(atomic_rmw, atomic_type=AtomicOpType.AND)
|
||||
atomic_or = functools.partial(atomic_rmw, atomic_type=AtomicOpType.OR)
|
||||
atomic_xor = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XOR)
|
||||
def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically exchanges the given value with the value at the given index.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG
|
||||
)
|
||||
|
||||
|
||||
def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] += val``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD
|
||||
)
|
||||
|
||||
|
||||
def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX
|
||||
)
|
||||
|
||||
|
||||
def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN
|
||||
)
|
||||
|
||||
|
||||
def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] &= val``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND
|
||||
)
|
||||
|
||||
|
||||
def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] |= val``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR
|
||||
)
|
||||
|
||||
|
||||
def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None):
|
||||
"""Atomically computes ``x_ref_or_view[idx] ^= val``.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to operate on.
|
||||
idx: The indexer to use.
|
||||
mask: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
The value at the given index prior to the atomic operation.
|
||||
"""
|
||||
return _atomic_rmw(
|
||||
x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR
|
||||
)
|
||||
|
||||
atomic_cas_p = jax_core.Primitive("atomic_cas")
|
||||
|
||||
@atomic_cas_p.def_effectful_abstract_eval
|
||||
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
|
||||
if cmp_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Dtypes in cmp/val need to match")
|
||||
if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape:
|
||||
raise ValueError("cmp and val must have identical dtypes and shapes")
|
||||
if ref_aval.shape:
|
||||
raise ValueError("Ref must be scalar.")
|
||||
raise ValueError("ref must be scalar.")
|
||||
if cmp_aval.shape:
|
||||
raise ValueError("Cmp must be scalar.")
|
||||
raise ValueError("cmp must be scalar.")
|
||||
if val_aval.shape:
|
||||
raise ValueError("Val must be scalar.")
|
||||
if cmp_aval.shape != val_aval.shape:
|
||||
raise ValueError("Dtypes in cmp/val need to match")
|
||||
raise ValueError("val must be scalar.")
|
||||
return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)}
|
||||
atomic_cas_p.def_effectful_abstract_eval(_atomic_cas_abstract_eval)
|
||||
|
||||
|
||||
def atomic_cas(ref, cmp, val):
|
||||
"""Performs an atomic compare-and-swap of the value in the ref with the
|
||||
given value.
|
||||
|
||||
Args:
|
||||
ref: The ref to operate on.
|
||||
cmp: The expected value to compare against.
|
||||
val: The value to swap in.
|
||||
|
||||
Returns:
|
||||
The value in the ref prior to the atomic operation.
|
||||
"""
|
||||
return atomic_cas_p.bind(ref, cmp, val)
|
||||
|
||||
@state_discharge.register_discharge_rule(atomic_cas_p)
|
||||
@ -223,9 +335,9 @@ def max_contiguous(x, values):
|
||||
values = [values]
|
||||
return max_contiguous_p.bind(x, values=values)
|
||||
|
||||
@max_contiguous_p.def_abstract_eval
|
||||
def _max_contiguous_abstract_eval(aval, **_):
|
||||
return aval
|
||||
max_contiguous_p.def_abstract_eval(_max_contiguous_abstract_eval)
|
||||
|
||||
multiple_of_p = jax_core.Primitive("multiple_of")
|
||||
|
||||
@ -237,13 +349,14 @@ def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array:
|
||||
values = [values]
|
||||
return multiple_of_p.bind(x, values=values)
|
||||
|
||||
@multiple_of_p.def_abstract_eval
|
||||
def _multiple_of_abstract_eval(aval, **_):
|
||||
return aval
|
||||
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)
|
||||
|
||||
load_p = jax_core.Primitive('masked_load')
|
||||
|
||||
|
||||
@load_p.def_effectful_abstract_eval
|
||||
def _load_abstract_eval(*avals_flat, args_tree, **_):
|
||||
ref, indexers, _, _ = args_tree.unflatten(avals_flat)
|
||||
return (
|
||||
@ -252,8 +365,6 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):
|
||||
)
|
||||
|
||||
|
||||
load_p.def_effectful_abstract_eval(_load_abstract_eval)
|
||||
|
||||
def _load_pp_rule(eqn, context, settings):
|
||||
# Pretty prints `a = load x i` as `x[i] <- a`
|
||||
y, = eqn.outvars
|
||||
@ -339,6 +450,8 @@ def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
|
||||
_unpad_values_to_avoid_dynamic_slice_oob_shift = partial(
|
||||
_pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True)
|
||||
|
||||
|
||||
@state_discharge.register_discharge_rule(load_p)
|
||||
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
del out_avals # Unused.
|
||||
ref, indexers, mask, other = args_tree.unflatten(args_flat)
|
||||
@ -371,11 +484,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
return (None,) * len(in_avals), out
|
||||
|
||||
|
||||
state_discharge.register_discharge_rule(load_p)(_load_discharge_rule)
|
||||
|
||||
swap_p = jax_core.Primitive('masked_swap')
|
||||
|
||||
|
||||
@swap_p.def_effectful_abstract_eval
|
||||
def _swap_abstract_eval(*avals_flat, args_tree, **_):
|
||||
ref, indexers, val, _ = args_tree.unflatten(avals_flat)
|
||||
expected_output_shape = indexers[-1].get_indexer_shape()
|
||||
@ -395,8 +507,6 @@ def _swap_abstract_eval(*avals_flat, args_tree, **_):
|
||||
)
|
||||
|
||||
|
||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
def _swap_pp_rule(eqn, context, settings):
|
||||
# Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
|
||||
# or:
|
||||
@ -449,6 +559,7 @@ def _swap_jvp(primals, tangents, *, args_tree, **params):
|
||||
ad.primitive_jvps[swap_p] = _swap_jvp
|
||||
|
||||
|
||||
@state_discharge.register_discharge_rule(swap_p)
|
||||
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
del out_avals # Unused.
|
||||
ref, indexers, val, mask = args_tree.unflatten(args_flat)
|
||||
@ -493,11 +604,24 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
return (x_new,) + (None,) * (len(in_avals) - 1), out
|
||||
|
||||
|
||||
state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
|
||||
|
||||
|
||||
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
|
||||
eviction_policy=None, volatile=False) -> jax.Array:
|
||||
"""Returns an array loaded from the given index.
|
||||
|
||||
If neither ``mask`` nor ``other`` is specified, this function has the same
|
||||
semantics as ``x_ref_or_view[idx]`` in JAX.
|
||||
|
||||
Args:
|
||||
x_ref_or_view: The ref to load from.
|
||||
idx: The indexer to use.
|
||||
mask: An optional boolean mask specifying which indices to load.
|
||||
If mask is ``False`` and ``other`` is not given, no assumptions can
|
||||
be made about the value in the resulting array.
|
||||
other: An optional value to use for indices where mask is ``False``.
|
||||
cache_modifier: TO BE DOCUMENTED.
|
||||
eviction_policy: TO BE DOCUMENTED.
|
||||
volatile: TO BE DOCUMENTED.
|
||||
"""
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
|
||||
return load_p.bind(
|
||||
@ -509,7 +633,14 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
|
||||
)
|
||||
|
||||
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
|
||||
_function_name="swap") -> Any:
|
||||
_function_name="swap") -> jax.Array:
|
||||
"""Swaps the value at the given index and returns the old value.
|
||||
|
||||
See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
|
||||
|
||||
Returns:
|
||||
The value stored in the ref prior to the swap.
|
||||
"""
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
|
||||
return swap_p.bind(
|
||||
@ -517,6 +648,10 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
|
||||
)
|
||||
|
||||
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
|
||||
"""Stores a value at the given index.
|
||||
|
||||
See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
|
||||
"""
|
||||
_ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
|
||||
_function_name="store")
|
||||
|
||||
@ -562,14 +697,15 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
|
||||
Args:
|
||||
fmt: A format string to be included in the output. The restrictions on the
|
||||
format string depend on the backend:
|
||||
* On GPU, when using Triton, ``fmt`` must not contain any placeholders
|
||||
(``{...}``), since it is always printed before any of the values.
|
||||
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
|
||||
contain a placeholder for each value to be printed. Format specs and
|
||||
conversions are not supported.
|
||||
* In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
|
||||
integers. If there are no placeholders, the values are printed after
|
||||
the format string.
|
||||
|
||||
* On GPU, when using Triton, ``fmt`` must not contain any placeholders
|
||||
(``{...}``), since it is always printed before any of the values.
|
||||
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
|
||||
contain a placeholder for each value to be printed. Format specs and
|
||||
conversions are not supported.
|
||||
* On TPU, if ``fmt`` contains placeholders, all values must be 32-bit
|
||||
integers. If there are no placeholders, the values are printed after
|
||||
the format string.
|
||||
*args: The scalar values to print.
|
||||
""" # fmt: skip
|
||||
has_placeholders = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user