# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for pulling block specs through a fusion."""

from __future__ import annotations

import contextlib
import dataclasses
import enum
import functools
import threading
from typing import Any, Callable, Protocol, Sequence

import jax
from jax import lax
from jax._src import ad_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import pjit
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.fuser import fuser_utils
import jax.numpy as jnp
import numpy as np

# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors

pull_block_spec_rules: dict[core.Primitive, PullBlockSpecRuleFn] = {}


@dataclasses.dataclass
class PullRuleContext:
  avals_in: tuple[core.AbstractValue, ...]
  avals_out: tuple[core.AbstractValue, ...]
  out_usages: tuple[set[Usage], ...]
  eval_function: Any = dataclasses.field(default=None, init=False)
  scalar_prefetch_fn: Any = dataclasses.field(default=None, init=False)
  scalar_prefetch_handler: Any | None
  grid: tuple[int | jax.Array, ...] | None

  def __post_init__(self):
    self._scalar_prefetch = None

  def set_eval_function(self, eval_function):
    self.eval_function = eval_function
    return eval_function


@dataclasses.dataclass
class PushRuleContext:
  avals_in: tuple[core.AbstractValue, ...]
  avals_out: tuple[core.AbstractValue, ...]


def make_scalar_prefetch_handler(*args):
  def scalar_prefetch_getter(*sp_inputs):
    result = sp_inputs
    for i in args:
      result = result[i]
    return result

  return scalar_prefetch_getter
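
# Illustrative sketch (not part of the module API): a handler built with
# `make_scalar_prefetch_handler(1, 0)` drills into the scalar prefetch values
# by successive indexing, so `handler(a, (x, y))` returns `x`, i.e.
# `(a, (x, y))[1][0]`.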


def _default_eval_fn(eqn, eval_ctx, *args):
  del eval_ctx
  out = eqn.primitive.bind(*args, **eqn.params)
  if eqn.primitive.multiple_results:
    return out
  return [out]


def _wrap_eval_fn(primitive, eval_fn):
  def wrapped(*args):
    if primitive.multiple_results:
      return eval_fn(*args)
    return [eval_fn(*args)]

  return wrapped


@dataclasses.dataclass
class UsageRuleContext:
  avals_in: tuple[core.AbstractValue, ...]
  avals_out: tuple[core.AbstractValue, ...]


def compute_usage(jaxpr: core.Jaxpr, jaxpr_out_usages):
  # TODO(sharadmv): maybe simplify this by only identifying scalar prefetch
  # and then DCEing it out.
  usage_env: dict[core.Atom, set[Usage]] = {}

  def read_usage_env(atom: core.Atom) -> set[Usage]:
    assert not isinstance(atom, core.Literal)
    return usage_env.get(atom, {Usage.REGULAR})

  def write_usage_env(atom: core.Atom, usage: set[Usage]):
    if isinstance(atom, core.Literal):
      return
    if atom not in usage_env:
      usage_env[atom] = set()
    usage_env[atom] |= usage

  util.safe_map(write_usage_env, jaxpr.outvars, jaxpr_out_usages)

  for eqn in jaxpr.eqns[::-1]:
    out_usages = util.safe_map(read_usage_env, eqn.outvars)
    rule = usage_rules.get(eqn.primitive, None)
    if rule:
      usage_rule_ctx = UsageRuleContext(
          avals_in=tuple(v.aval for v in eqn.invars),
          avals_out=tuple(v.aval for v in eqn.outvars),
      )
      if eqn.primitive.multiple_results:
        in_usages = rule(usage_rule_ctx, out_usages, **eqn.params)
      else:
        in_usages = rule(usage_rule_ctx, out_usages[0], **eqn.params)
    else:
      # Usages are forwarded.
      all_usages = set.union(*out_usages)
      in_usages = [all_usages] * len(eqn.invars)
    util.safe_map(write_usage_env, eqn.invars, in_usages)
  return read_usage_env
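
# Worked example (illustrative): for `out = lax.dynamic_slice(x, i, j)` with a
# REGULAR output usage, the dynamic_slice usage rule below marks `x` as
# REGULAR and the start indices `i`, `j` as SCALAR_PREFETCH; the backward walk
# then propagates SCALAR_PREFETCH into whatever computed those indices.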


@dataclasses.dataclass(frozen=True)
class KernelEvalContext:
  scalar_prefetch: Any | None
  program_ids: tuple[int | jax.Array, ...] | None
  avals_in: tuple[core.AbstractValue, ...] | None
  avals_out: tuple[core.AbstractValue, ...] | None
  in_block_specs: tuple[pallas_core.BlockSpec, ...] | None
  out_block_specs: tuple[pallas_core.BlockSpec, ...] | None
  grid: tuple[int | jax.Array, ...] | None
  scalar_prefetch_handler: Any | None
  out_usages: tuple[set[Usage], ...] | None

  def get_program_ids(self):
    if self.program_ids is None:
      raise ValueError('Program ids not available.')
    return self.program_ids

  def get_in_block_indices(self):
    with _sp_context(*self.scalar_prefetch):
      return tuple(
          bs.index_map(*self.program_ids) for bs in self.in_block_specs
      )

  def get_out_block_indices(self):
    with _sp_context(*self.scalar_prefetch):
      return tuple(
          bs.index_map(*self.program_ids) for bs in self.out_block_specs
      )


_illegal = object()

_sp_env = threading.local()
_sp_env.scalar_prefetch = None


@contextlib.contextmanager
def _sp_context(*scalar_prefetch):
  assert _sp_env.scalar_prefetch is None, scalar_prefetch
  _sp_env.scalar_prefetch = scalar_prefetch
  try:
    yield
  finally:
    _sp_env.scalar_prefetch = None


def _get_scalar_prefetch():
  return _sp_env.scalar_prefetch


def _wrap_block_spec_scalar_prefetch(
    block_spec: pallas_core.BlockSpec,
    num_grid_args: int,
) -> pallas_core.BlockSpec:
  if block_spec is pallas_core.no_block_spec:
    return block_spec

  def new_index_map(*args_and_scalar_prefetch):
    args, scalar_prefetch = util.split_list(
        args_and_scalar_prefetch,
        [num_grid_args],
    )
    with _sp_context(*scalar_prefetch):
      return block_spec.index_map(*args)

  return block_spec.replace(index_map=new_index_map)


_unwrap_cache: dict[int, pallas_core.BlockSpec] = {}


def _unwrap_block_spec_scalar_prefetch(
    block_spec: pallas_core.BlockSpec,
) -> pallas_core.BlockSpec:
  if id(block_spec) in _unwrap_cache:
    return _unwrap_cache[id(block_spec)]

  def new_index_map(*args):
    scalar_prefetch = _get_scalar_prefetch()
    assert scalar_prefetch is not None
    return block_spec.index_map(*args, *scalar_prefetch)

  out_block_spec = block_spec.replace(index_map=new_index_map)
  _unwrap_cache[id(block_spec)] = out_block_spec
  return out_block_spec


def pull_block_spec(
    f: Callable,
    out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
    *,
    scalar_prefetch_handler: Any | None = None,
    grid: tuple[int | jax.Array, ...] | None = None,
):
  def wrapped(*args, **kwargs):
    jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
        f, *args, **kwargs
    )
    # TODO(sharadmv): handle these consts better, they should correspond to
    # scalar prefetch.
    del consts, out_tree_
    jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars)
    block_specs_ = jax.tree.map(
        _unwrap_block_spec_scalar_prefetch, out_block_specs
    )
    flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
    jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
        jaxpr,
        used_outputs=[True] * len(jaxpr.outvars),
        instantiate=True,
    )
    assert all(used_invars)
    assert all(used_consts)
    in_block_specs, env, read_usage_env = _pull_block_spec(
        jaxpr,
        tuple(flat_block_specs),
        jaxpr_out_usages,
        scalar_prefetch_handler=scalar_prefetch_handler,
        grid=grid,
    )
    kernel_fn = make_kernel_function(
        jaxpr,
        in_tree,
        out_tree,
        read_usage_env,
        in_block_specs,
        env,
        scalar_prefetch_handler,
        grid,
    )
    in_block_specs = jax.tree.unflatten(in_tree, in_block_specs)
    in_block_specs = jax.tree.map(
        functools.partial(
            _wrap_block_spec_scalar_prefetch,
            num_grid_args=len(grid),
        ),
        in_block_specs,
    )
    in_block_arg_specs, in_block_kwarg_specs = in_block_specs
    return kernel_fn, in_block_arg_specs, in_block_kwarg_specs

  return wrapped
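
# Illustrative usage sketch (hypothetical shapes, elementwise `f`):
#
#   def f(x):
#     return jnp.exp(x)
#
#   out_spec = pallas_core.BlockSpec((128, 128), lambda i, j: (i, j))
#   kernel_fn, in_arg_specs, in_kwarg_specs = pull_block_spec(
#       f, out_spec, grid=(4, 4)
#   )(x)
#
# For an elementwise function the pulled input spec matches the output spec,
# and `kernel_fn(program_ids, scalar_prefetch, x_block)` evaluates `f` on a
# single block.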


def _pull_block_spec(
    jaxpr: core.Jaxpr,
    out_block_specs: tuple[pallas_core.BlockSpec, ...],
    out_usages,
    *,
    scalar_prefetch_handler: Any | None = None,
    grid: tuple[int | jax.Array, ...],
) -> tuple[
    tuple[pallas_core.BlockSpec | pallas_core.NoBlockSpec, ...],
    tuple[dict[core.Var, pallas_core.BlockSpec], dict[int, Any]],
    Any,
]:
  read_usage_env = compute_usage(jaxpr, out_usages)
  jaxpr_invar_usages = util.safe_map(read_usage_env, jaxpr.invars)
  env: dict[core.Var, pallas_core.BlockSpec] = {}
  scalar_prefetch_fn_env = {}

  for outvar, bs in zip(jaxpr.outvars, out_block_specs, strict=True):
    assert isinstance(outvar, core.Var)
    env[outvar] = bs

  def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any:
    if isinstance(atom, core.Literal):
      return pallas_core.no_block_spec
    return env[atom]

  def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec):
    if isinstance(atom, core.Literal):
      return
    env[atom] = block_spec

  for i, eqn in reversed(list(enumerate(jaxpr.eqns))):
    eqn_out_block_specs = tuple(util.safe_map(_read_block_spec, eqn.outvars))
    rule = pull_block_spec_rules.get(eqn.primitive, None)
    if not rule:
      raise NotImplementedError(eqn.primitive)
    ctx = PullRuleContext(
        avals_in=tuple(v.aval for v in eqn.invars),
        avals_out=tuple(v.aval for v in eqn.outvars),
        # The per-equation context carries the usages of this equation's
        # outputs (not the whole jaxpr's outputs).
        out_usages=tuple(read_usage_env(v) for v in eqn.outvars),
        scalar_prefetch_handler=scalar_prefetch_handler,
        grid=grid,
    )
    if eqn.primitive.multiple_results:
      in_block_specs = rule(ctx, eqn_out_block_specs, **eqn.params)
    else:
      in_block_specs = rule(ctx, eqn_out_block_specs[0], **eqn.params)

    eqn_invar_usages = [
        read_usage_env(v) if not isinstance(v, core.Literal) else set()
        for v in eqn.invars
    ]
    if any(Usage.SCALAR_PREFETCH in usage for usage in eqn_invar_usages):
      scalar_prefetch_vars = [
          Usage.SCALAR_PREFETCH in usage for usage in eqn_invar_usages
      ]
      needed_invars = [
          v
          for v, sp in zip(eqn.invars, scalar_prefetch_vars)
          if sp or isinstance(v, core.Literal)
      ]
      scalar_prefetch_jaxpr_no_dce = core.Jaxpr(
          jaxpr.constvars,
          jaxpr.invars,
          needed_invars,
          jaxpr.eqns[: jaxpr.eqns.index(eqn)],
          debug_info=jaxpr.debug_info,
      )
      scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
          scalar_prefetch_jaxpr_no_dce,
          [True] * len(scalar_prefetch_jaxpr_no_dce.outvars),
      )
      assert not any(used_invars)
      scalar_prefetch_jaxpr = scalar_prefetch_jaxpr.replace(
          constvars=[],
          invars=jaxpr.constvars,
      )

      def _scalar_prefetch_fn(jaxpr):
        if grid is None:
          raise ValueError('Grid must be provided to pull_block_spec.')
        args = scalar_prefetch_handler(*_get_scalar_prefetch())
        # Load from SMEM.
        args = [a[0] for a in args]
        return core.eval_jaxpr(jaxpr, [], *args)

      scalar_prefetch_fn = functools.partial(
          _scalar_prefetch_fn, scalar_prefetch_jaxpr
      )
      ctx.scalar_prefetch_fn = scalar_prefetch_fn_env[i] = scalar_prefetch_fn
    for v, in_block_spec in zip(eqn.invars, in_block_specs, strict=True):
      if (
          not isinstance(v, core.Literal)
          and v in env
          and (bs := env[v]) != in_block_spec
      ):
        if bs.block_shape != in_block_spec.block_shape:
          in_block_spec = in_block_spec.replace(block_shape=_illegal)
        in_block_spec = in_block_spec.replace(index_map=_illegal)
      _write_block_spec(v, in_block_spec)

  def _get_in_block_spec(v, usage):
    if usage == {Usage.SCALAR_PREFETCH}:
      return None
    bs = env.get(v, pallas_core.no_block_spec)
    if bs is not pallas_core.no_block_spec:
      if bs.index_map is _illegal:  # pytype: disable=attribute-error
        raise ValueError(f'Found cycle:\n{jaxpr}')
    return bs

  in_block_specs = tuple(
      _get_in_block_spec(v, usage)
      for v, usage in zip(jaxpr.invars, jaxpr_invar_usages)
  )
  return (
      tuple(in_block_specs),
      (env, scalar_prefetch_fn_env),
      read_usage_env,
  )


def make_kernel_function(
    jaxpr: core.Jaxpr,
    in_tree,
    out_tree,
    read_usage_env,
    in_block_specs,
    block_spec_env,
    scalar_prefetch_handler,
    grid,
):
  in_avals = [v.aval for v in jaxpr.invars]
  invar_usages = util.safe_map(read_usage_env, jaxpr.invars)
  bs_env, scalar_prefetch_fn_env = block_spec_env

  def _remove_nones(shape: tuple[int | None, ...] | None) -> tuple[int, ...]:
    assert shape is not None
    return tuple(s for s in shape if s is not None)

  _no_aval = object()

  def _get_block_aval(bs, aval):
    if bs is pallas_core.no_block_spec or bs is None:
      return _no_aval
    return aval.update(shape=_remove_nones(bs.block_shape))  # pytype: disable=attribute-error

  in_block_avals = [
      _get_block_aval(bs, aval)
      for aval, bs in zip(in_avals, in_block_specs, strict=True)
  ]
  unflat_in_block_arg_avals, unflat_in_block_kwarg_avals = (
      tree_util.tree_unflatten(in_tree, in_block_avals)
  )
  unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten(
      in_tree, invar_usages
  )

  def sds_like(x):
    if x is _no_aval:
      return _no_aval
    return jax.ShapeDtypeStruct(x.shape, x.dtype)

  kernel_in_type = jax.tree.map(
      sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals)
  )

  def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any:
    if isinstance(atom, core.Literal):
      return pallas_core.no_block_spec
    return bs_env[atom]

  def kernel_fn(program_ids, scalar_prefetch, *args, **kwargs):
    def _check_args(prefix, path, x, y, usage):
      if usage == {Usage.SCALAR_PREFETCH}:
        return
      if y is _no_aval:
        return
      x_aval, y_aval = core.get_aval(x), core.get_aval(y)
      if x_aval.shape != y_aval.shape:
        raise ValueError(
            f'Shapes do not match: actual={x_aval.shape} !='
            f' expected={y_aval.shape}. Path:'
            f' {prefix}{jax.tree_util.keystr(path)}. Expected type:'
            f' {kernel_in_type}. Actual args: {(args, kwargs)}'
        )
      if x_aval.dtype != y_aval.dtype:
        raise ValueError(
            f'DTypes do not match: actual={x_aval.dtype} !='
            f' expected={y_aval.dtype}. Path:'
            f' {prefix}{jax.tree_util.keystr(path)}. Expected type:'
            f' {kernel_in_type}. Actual args: {(args, kwargs)}'
        )

    jax.tree_util.tree_map_with_path(
        functools.partial(_check_args, 'args'),
        args,
        kernel_in_type[0],
        unflat_arg_usages,
    )
    jax.tree_util.tree_map_with_path(
        functools.partial(_check_args, 'kwargs'),
        kwargs,
        kernel_in_type[1],
        unflat_kwarg_usages,
    )
    flat_args, in_tree_ = tree_util.tree_flatten((args, kwargs))
    if in_tree_ != tree_util.tree_structure(kernel_in_type):
      raise ValueError(f'Expected {kernel_in_type} PyTree, got {in_tree_}')
    env = {}

    def read_env(atom):
      match atom:
        case core.Literal():
          return atom.val
        case core.Var():
          return env.get(atom, None)

    def write_env(var, val):
      env[var] = val

    for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages):
      if Usage.REGULAR in usage:
        env[invar] = arg
    for i, eqn in enumerate(jaxpr.eqns):
      outvar_usages = [
          read_usage_env(v) if not isinstance(v, core.Literal) else set()
          for v in eqn.outvars
      ]
      if any(Usage.REGULAR in usage for usage in outvar_usages):
        in_vals = util.safe_map(read_env, eqn.invars)
        # TODO(sharadmv,justinfu): preserve source mapping
        if not (eval_rule := eval_rules.get(eqn.primitive, None)):
          raise NotImplementedError(eqn.primitive)
        out_block_specs = tuple(util.safe_map(_read_block_spec, eqn.outvars))
        in_block_specs = tuple(util.safe_map(_read_block_spec, eqn.invars))
        out_usages = tuple(read_usage_env(v) for v in eqn.outvars)
        # We need to substitute scalar prefetch values into the equation
        # invars.
        if scalar_prefetch_fn := scalar_prefetch_fn_env.get(i, None):
          # Evaluate the scalar prefetch function.
          with _sp_context(*scalar_prefetch):
            scalar_prefetch_vals = scalar_prefetch_fn()
          # Some results of the SP function will be literals.
          eqn_invar_usages = [
              read_usage_env(v)
              if not isinstance(v, core.Literal)
              else {Usage.SCALAR_PREFETCH}
              for v in eqn.invars
          ]
          sp_iter = iter(scalar_prefetch_vals)
          for j, usage in enumerate(eqn_invar_usages):
            if usage == {Usage.SCALAR_PREFETCH}:
              if not isinstance(eqn.invars[j], core.Literal):
                in_vals[j] = next(sp_iter)
        eval_ctx = KernelEvalContext(
            avals_in=tuple(v.aval for v in eqn.invars),
            avals_out=tuple(v.aval for v in eqn.outvars),
            scalar_prefetch=scalar_prefetch,
            program_ids=tuple(program_ids),
            in_block_specs=in_block_specs,
            out_block_specs=out_block_specs,
            scalar_prefetch_handler=scalar_prefetch_handler,
            grid=grid,
            out_usages=out_usages,
        )
        outs = eval_rule(eval_ctx, *in_vals, **eqn.params)
        if not eqn.primitive.multiple_results:
          outs = [outs]
        util.safe_map(write_env, eqn.outvars, outs)
    out = util.safe_map(read_env, jaxpr.outvars)
    return tree_util.tree_unflatten(out_tree, out)

  return kernel_fn


def get_fusion_values(
    fusion: Callable, *args, **kwargs
) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]:
  jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr(
      fusion, *args, **kwargs
  )
  assert len(values) == len(jaxpr.constvars), (jaxpr, values)
  out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars)
  read_usage_env = compute_usage(jaxpr, out_usages)
  constvar_usages = util.safe_map(read_usage_env, jaxpr.constvars)
  invar_usages = util.safe_map(read_usage_env, jaxpr.invars)
  del invar_usages  # These don't correspond to values.
  # Add a leading dimension to scalar prefetch values so Mosaic won't be
  # upset.
  is_scalar_prefetch = tuple(
      Usage.SCALAR_PREFETCH in usage for usage in constvar_usages
  )
  regular_values, scalar_prefetch_values = util.partition_list(
      is_scalar_prefetch, values
  )

  def new_kernel_fn(values, *args, **kwargs):
    values = util.merge_lists(
        is_scalar_prefetch, values, scalar_prefetch_values
    )
    flat_args, _ = tree_util.tree_flatten((args, kwargs))
    out_flat = core.eval_jaxpr(jaxpr, values, *flat_args)
    return tree_util.tree_unflatten(out_tree, out_flat)

  return new_kernel_fn, tuple(regular_values), tuple(scalar_prefetch_values)


# # Interpreter rules

# ## Usage interpreter rules


class Usage(enum.Enum):
  REGULAR = 0
  SCALAR_PREFETCH = 1


class UsageRuleFn(Protocol):

  def __call__(
      self,
      ctx: UsageRuleContext,
      used_outs: Sequence[set[Usage]] | set[Usage],
      **params: Any,
  ) -> Sequence[set[Usage]]:
    ...


usage_rules: dict[core.Primitive, UsageRuleFn] = {}


def register_usage_rule(
    prim: core.Primitive,
) -> Callable[[UsageRuleFn], UsageRuleFn]:

  def wrapper(
      f: UsageRuleFn,
  ) -> UsageRuleFn:
    usage_rules[prim] = f
    return f

  return wrapper


# ## Eval interpreter rules


class EvalRuleFn(Protocol):

  def __call__(
      self,
      ctx: KernelEvalContext,
      *args: Any,
      **params: Any,
  ) -> Sequence[Any]:
    ...


eval_rules: dict[core.Primitive, EvalRuleFn] = {}


def register_eval_rule(
    prim: core.Primitive,
) -> Callable[[EvalRuleFn], EvalRuleFn]:
  def wrapper(
      f: EvalRuleFn,
  ) -> EvalRuleFn:
    eval_rules[prim] = f
    return f

  return wrapper


# ## Pull block spec interpreter rules


class PullBlockSpecRuleFn(Protocol):

  def __call__(
      self,
      ctx: PullRuleContext,
      block_spec: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
      **params: Any,
  ) -> Sequence[pallas_core.BlockSpec]:
    ...


def register_pull_block_spec_rule(
    prim: core.Primitive,
) -> Callable[[PullBlockSpecRuleFn], PullBlockSpecRuleFn]:

  def wrapper(
      f: PullBlockSpecRuleFn,
  ) -> PullBlockSpecRuleFn:
    pull_block_spec_rules[prim] = f
    return f

  return wrapper


# Primitive rule implementations


def _eltwise_eval_rule(prim, ctx, x, **params):
  del ctx
  return prim.bind(x, **params)


def _eltwise_pull_rule(
    prim: core.Primitive,
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    **params,
) -> Sequence[pallas_core.BlockSpec]:
  del prim, ctx, params
  return [block_spec]


def _eltwise_usage_rule(
    prim: core.Primitive, ctx: UsageRuleContext, used_out: set[Usage], **params
) -> Sequence[set[Usage]]:
  del ctx, prim, params
  return [used_out]


def _bcast_block_spec(
    block_spec: pallas_core.BlockSpec, i: int
) -> pallas_core.BlockSpec:
  def new_index_map(i, *args):
    idx = block_spec.index_map(*args)
    assert len(idx) == len(block_spec.block_shape)
    idx = util.tuple_update(idx, i, 0)
    return idx

  new_block_shape = util.tuple_update(block_spec.block_shape, i, 1)
  return pallas_core.BlockSpec(
      new_block_shape, functools.partial(new_index_map, i)
  )
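
# Illustrative example: `_bcast_block_spec(spec, 0)` on a (128, 128) block
# spec yields block shape (1, 128) and pins the dimension-0 block index to 0,
# so every grid step reads the same (broadcasted) block along that dimension.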


def _binop_usage_rule(prim, ctx, used_out: set[Usage]):
  del prim
  if used_out == {Usage.SCALAR_PREFETCH}:
    return [{Usage.SCALAR_PREFETCH}, {Usage.SCALAR_PREFETCH}]
  elif used_out == {Usage.REGULAR}:
    usage = [{Usage.REGULAR} for _ in ctx.avals_in]
    return usage
  else:
    return [set()] * len(ctx.avals_in)


def _binop_eval_rule(prim, ctx, x, y):
  del ctx
  return prim.bind(x, y)


def _binop_pull_rule(prim, ctx: PullRuleContext, block_spec):
  l_block_spec = block_spec
  r_block_spec = block_spec
  left_aval, right_aval = ctx.avals_in
  assert isinstance(left_aval, core.ShapedArray)
  assert isinstance(right_aval, core.ShapedArray)

  @ctx.set_eval_function
  def _eval_function(_, x, y):
    sp_index = 0
    if x is None:
      x = ctx.scalar_prefetch_fn()[sp_index]
      sp_index += 1
    if y is None:
      y = ctx.scalar_prefetch_fn()[sp_index]
    return prim.bind(x, y)

  if not right_aval.shape:
    return [block_spec, pallas_core.no_block_spec]
  if not left_aval.shape:
    return [pallas_core.no_block_spec, block_spec]
  for i, (l, r) in enumerate(
      zip(left_aval.shape, right_aval.shape, strict=True)
  ):
    if l == 1 and r != 1:
      l_block_spec = _bcast_block_spec(l_block_spec, i)
    if r == 1 and l != 1:
      r_block_spec = _bcast_block_spec(r_block_spec, i)

  return [l_block_spec, r_block_spec]
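
# Illustrative example: pulling a (128, 128) output spec through `x * y` with
# `x` of shape (512, 512) and `y` of shape (1, 512) leaves `x`'s spec
# unchanged but rewrites `y`'s via `_bcast_block_spec` to block shape
# (1, 128), since `y` is broadcast along dimension 0; a rank-0 operand gets
# `no_block_spec` instead.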


def register_binop_rule(prim: core.Primitive):
  register_pull_block_spec_rule(prim)(functools.partial(_binop_pull_rule, prim))
  register_usage_rule(prim)(functools.partial(_binop_usage_rule, prim))
  register_eval_rule(prim)(functools.partial(_binop_eval_rule, prim))


register_binop_rule(lax.mul_p)
register_binop_rule(lax.add_p)
register_binop_rule(lax.sub_p)
register_binop_rule(lax.div_p)
register_binop_rule(lax.max_p)
register_binop_rule(lax.lt_p)
register_binop_rule(lax.le_p)
register_binop_rule(lax.eq_p)
register_binop_rule(lax.gt_p)
register_binop_rule(lax.ge_p)
register_binop_rule(lax.and_p)
register_binop_rule(ad_util.add_any_p)


@register_eval_rule(lax.select_n_p)
def _select_n_eval_rule(ctx: KernelEvalContext, *args):
  return jax.lax.select_n(*args)


@register_pull_block_spec_rule(lax.select_n_p)
def _select_n_pull_block_spec_rule(
    ctx: PullRuleContext, block_spec: pallas_core.BlockSpec
) -> Sequence[pallas_core.BlockSpec]:
  in_aval = ctx.avals_in[0]
  assert isinstance(in_aval, core.ShapedArray)
  if in_aval.shape:
    return [block_spec] * len(ctx.avals_in)
  return [pallas_core.no_block_spec, *[block_spec] * (len(ctx.avals_in) - 1)]


@register_eval_rule(lax.squeeze_p)
def _squeeze_eval_rule(ctx: KernelEvalContext, x: jax.Array, **params: Any):
  del ctx, params
  return x


@register_pull_block_spec_rule(lax.squeeze_p)
def _squeeze_block_spec(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    dimensions: tuple[int, ...],
) -> Sequence[pallas_core.BlockSpec]:
  del ctx
  if block_spec is pallas_core.no_block_spec:
    return [pallas_core.no_block_spec]

  def new_index_map(*args):
    idx = tuple(block_spec.index_map(*args))
    assert len(idx) == len(block_spec.block_shape)
    for dim in dimensions:
      idx = util.tuple_insert(idx, dim, 0)
    return idx

  new_block_shape = tuple(block_spec.block_shape)
  for dim in dimensions:
    new_block_shape = util.tuple_insert(new_block_shape, dim, None)

  return [pallas_core.BlockSpec(new_block_shape, new_index_map)]


@register_eval_rule(lax.slice_p)
def _slice_eval_rule(ctx, x, **params):
  del params
  out_block_shape = ctx.out_block_specs[0].block_shape
  assert len(x.shape) == sum(1 for bs in out_block_shape if bs is not None)
  return x


@register_pull_block_spec_rule(lax.slice_p)
def _slice_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    start_indices: tuple[int, ...],
    limit_indices: tuple[int, ...],
    strides: tuple[int, ...] | None,
):
  del ctx
  if strides is not None:
    raise NotImplementedError('strides are not supported yet')
  slice_sizes = tuple(
      int(end - start) for start, end in zip(start_indices, limit_indices)
  )
  for bs, slice_start, slice_size in zip(
      block_spec.block_shape, start_indices, slice_sizes
  ):
    if bs is None:
      continue
    assert slice_start % bs == 0, (start_indices, block_spec.block_shape)
    assert slice_size % bs == 0, (slice_sizes, block_spec.block_shape)
  offsets = tuple(
      slice_start // bs if bs is not None else slice_start
      for slice_start, bs in zip(start_indices, block_spec.block_shape)
  )

  def _offset(x, i):
    return x + i if i != 0 else x

  def new_index_map(*args):
    idx = block_spec.index_map(*args)
    assert len(idx) == len(block_spec.block_shape)
    return tuple(_offset(i, o) for i, o in zip(idx, offsets))

  return [pallas_core.BlockSpec(block_spec.block_shape, new_index_map)]
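
# Worked example (illustrative): pulling a block spec with block shape (128,)
# through `lax.slice(x, (256,), (512,))` requires 256 % 128 == 0 and
# (512 - 256) % 128 == 0, and offsets the pulled index map by 256 // 128 = 2:
# block i of the slice output is block i + 2 of `x`.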


@register_usage_rule(lax.dynamic_slice_p)
def _dynamic_slice_usage_rule(ctx, used_out: set[Usage], **params):
  del params
  if used_out == {Usage.SCALAR_PREFETCH}:
    raise NotImplementedError('scalar prefetch not supported yet')
  elif used_out == {Usage.REGULAR}:
    usage = [used_out] + [{Usage.SCALAR_PREFETCH}] * (len(ctx.avals_in) - 1)
    return usage
  else:
    return [set()] * len(ctx.avals_in)


def _offset(x, i, s):
  from jax.experimental import checkify

  if s is not None:
    pred = i % s == 0
    if isinstance(pred, jax.Array):
      checkify.check(i % s == 0, 'Invalid index', debug=True)
    else:
      if not pred:
        raise ValueError('Invalid index')
  offset = jax.lax.div(i, s) if s is not None else i
  return x + offset


@register_eval_rule(lax.dynamic_slice_p)
def _dynamic_slice_eval_rule(ctx, x, *args, **params):
  del ctx, params
  return x


@register_pull_block_spec_rule(lax.dynamic_slice_p)
def _dynamic_slice_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    slice_sizes: tuple[int, ...],
):
  del slice_sizes

  def new_index_map(*args):
    slice_starts = ctx.scalar_prefetch_fn()
    if len(slice_starts) != len(block_spec.block_shape):
      raise ValueError(
          f'Expected {len(block_spec.block_shape)} slice starts, got'
          f' {len(slice_starts)}'
      )
    idx = block_spec.index_map(*args)
    assert len(idx) == len(block_spec.block_shape)

    # Once we have the indices, we need to offset them by the dynamic slice
    # indices. The dynamic slice indices index the full array. For example,
    # let's say we have an [l, m, n] array and are provided 3 dynamic slice
    # start indices [i, j, k] with sizes [s_l, s_m, s_n]. To perform the
    # slice, we need to compute the indices of the block that correspond to
    # that slice in the [l, m, n] array. If we have block sizes
    # [b_l, b_m, b_n], we require that i % b_l == 0, j % b_m == 0 and
    # k % b_n == 0, and that the slice sizes are multiples of the block
    # sizes. The indices of the block that correspond to the slice are then
    # given by (i // b_l, j // b_m, k // b_n). We then add these block
    # indices to the block indices produced by the index map.
    block_indices = tuple(
        _offset(i, o, s)
        for i, o, s in zip(
            idx, slice_starts, block_spec.block_shape, strict=True
        )
    )
    return block_indices

  new_block_spec = pallas_core.BlockSpec(block_spec.block_shape, new_index_map)
  return [new_block_spec] + [pallas_core.no_block_spec] * (
      len(ctx.avals_in) - 1
  )
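
# Numeric instance of the comment above (illustrative): for a [512, 512]
# array with block sizes [128, 128] and dynamic slice starts (256, 0), the
# slice begins at block (256 // 128, 0 // 128) = (2, 0), so the pulled index
# map adds (2, 0) to the block indices produced by the original index map.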


@register_eval_rule(lax.concatenate_p)
def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension):
  # We handle the case where the block size divides each of the concatenated
  # array dimensions.
  block_spec = ctx.out_block_specs[0]
  block_shape = block_spec.block_shape
  block_dim = block_shape[dimension]
  if block_dim is None:
    block_dim = 1

  if block_dim == sum(aval.shape[dimension] for aval in ctx.avals_in):  # pytype: disable=attribute-error
    # Handle the special case where the block contains all of the
    # concatenated array.
    return jax.lax.concatenate(args, dimension=dimension)

  num_blocks = []
  for aval in ctx.avals_in:
    assert isinstance(aval, core.ShapedArray)
    if aval.shape[dimension] % block_dim != 0:
      raise ValueError(
          f'Shape along concat dimension {dimension} must be divisible by the'
          f' block shape {block_shape[dimension]} for all children. Got shape'
          f' {aval.shape}.'
      )
    num_blocks.append(aval.shape[dimension] // block_dim)
  ends = np.cumsum(num_blocks).astype(np.int32)
  starts = np.concatenate(([0], ends[:-1])).astype(np.int32)

  block_indices = ctx.get_out_block_indices()[0]
  block_idx = block_indices[dimension]
  valid_index = 0
  for i in range(len(ctx.avals_in)):
    start, end = starts[i], ends[i]
    is_valid = (start <= block_idx) & (block_idx < end)
    valid_index = jax.lax.select(is_valid, i, valid_index)
  out_dtype = args[0].dtype
  # Upcast bfloat16 operands for the select, then cast the result back.
  args = [a.astype(jnp.float32) if a.dtype == jnp.bfloat16 else a for a in args]
  valid_block = jax.lax.select_n(valid_index, *args)
  return valid_block.astype(out_dtype)


@register_pull_block_spec_rule(lax.concatenate_p)
def _concatenate_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    dimension: int,
):
  block_shape = block_spec.block_shape
  num_blocks = []
  block_dim = block_shape[dimension]
  if block_dim is None:
    block_dim = 1
  if block_dim == sum(aval.shape[dimension] for aval in ctx.avals_in):  # pytype: disable=attribute-error
    # Handle the special case where the block contains all of the
    # concatenated array.
    new_shapes = [
        util.tuple_update(
            block_spec.block_shape, dimension, aval.shape[dimension]  # pytype: disable=attribute-error
        )
        for aval in ctx.avals_in
    ]
    new_block_specs = [
        block_spec.replace(block_shape=shape) for shape in new_shapes
    ]
    return new_block_specs

  # We now handle the case where the block size divides each of the
  # concatenated array dimensions.
  for aval in ctx.avals_in:
    assert isinstance(aval, core.ShapedArray)
    if aval.shape[dimension] % block_dim != 0:
      raise ValueError(
          f'Shape along concat dimension {dimension} must be divisible by the'
          f' block shape {block_shape[dimension]} for all children. Got shape'
          f' {aval.shape}.'
      )
    num_blocks.append(aval.shape[dimension] // block_dim)
  ends = np.cumsum(num_blocks).astype(np.int32)
  starts = np.concatenate(([0], ends[:-1])).astype(np.int32)

  def make_block_spec(child_index: int):
    def new_index_map(*args):
      idx = block_spec.index_map(*args)
      block_idx = idx[dimension]
      is_valid = (starts[child_index] <= block_idx) & (
          block_idx < ends[child_index]
      )
      padding_index = jnp.where(
          block_idx < starts[child_index], 0, num_blocks[child_index] - 1
      )
      block_idx = jnp.where(
          is_valid, block_idx - starts[child_index], padding_index
      )
      return util.tuple_update(idx, dimension, block_idx)

    return pallas_core.BlockSpec(block_spec.block_shape, new_index_map)

  return [make_block_spec(i) for i in range(len(ctx.avals_in))]
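
# Worked example (illustrative): concatenating operands of sizes 256 and 384
# along a dimension with block size 128 gives num_blocks = [2, 3], so
# starts = [0, 2] and ends = [2, 5]. Output block 3 is valid only for the
# second operand, mapping to its local block 3 - 2 = 1; for the first operand
# it is out of range, so its index map clamps to a padding block and the eval
# rule's select_n picks the valid operand's block.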


@register_usage_rule(lax.broadcast_in_dim_p)
def _broadcast_in_dim_usage_rule(ctx, used_out: set[Usage], **params):
  del params
  if used_out == {Usage.SCALAR_PREFETCH}:
    raise NotImplementedError('scalar prefetch not supported yet')
  elif used_out == {Usage.REGULAR}:
    return [
        {Usage.SCALAR_PREFETCH}
        if not ctx.avals_in[0].shape
        else {Usage.REGULAR}
    ]
  else:
    return [set()]


@register_eval_rule(lax.broadcast_in_dim_p)
def _broadcast_in_dim_eval_rule(
    eval_ctx: KernelEvalContext, x, broadcast_dimensions, **params
):
  if not eval_ctx.avals_in[0].shape:  # pytype: disable=attribute-error
    # Scalar -> array broadcast.
    block_spec = eval_ctx.out_block_specs[0]
    shape = tuple(s for s in block_spec.block_shape if s is not None)
    return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape)
  return x


@register_pull_block_spec_rule(lax.broadcast_in_dim_p)
def _broadcast_in_dim_pull_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    shape: tuple[int, ...],
    broadcast_dimensions: tuple[int, ...],
    sharding: jax.sharding.Sharding,
):
  del shape, sharding

  if not ctx.avals_in[0].shape:  # pytype: disable=attribute-error
    return [pallas_core.no_block_spec]

  def new_index_map(*args):
    idx = block_spec.index_map(*args)
    return tuple(idx[i] for i in broadcast_dimensions)

  new_block_shape = tuple(
      block_spec.block_shape[i] for i in broadcast_dimensions
  )
  return [pallas_core.BlockSpec(new_block_shape, new_index_map)]
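
# Illustrative example: pulling a (128, 128)-blocked spec through
# `lax.broadcast_in_dim` from shape (512,) to (512, 512) with
# broadcast_dimensions=(0,) keeps only dimension 0 of the spec: the pulled
# input spec has block shape (128,) and an index map returning only the
# first coordinate of the output index map's result. A rank-0 input instead
# gets `no_block_spec` and is re-broadcast in the eval rule above.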


@register_eval_rule(lax.transpose_p)
def _transpose_eval_rule(
    eval_ctx: KernelEvalContext, x, permutation: tuple[int, ...]
):
  block_spec = eval_ctx.out_block_specs[0]
  block_shape = block_spec.block_shape
  block_shape_no_nones = tuple(bs for bs in block_shape if bs is not None)
  block_dims_iter = iter(range(len(block_shape_no_nones)))
  expanded_block_dims = [
      None if bs is None else next(block_dims_iter) for bs in block_shape
  ]
  assert next(block_dims_iter, None) is None
  permuted_block_dims = [expanded_block_dims[p] for p in permutation]
  new_permutation = [p for p in permuted_block_dims if p is not None]
  return jax.lax.transpose(x, permutation=new_permutation)


@register_pull_block_spec_rule(lax.transpose_p)
def _transpose_pull_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    permutation: tuple[int, ...],
):
  block_shape = block_spec.block_shape
  new_shape = tuple(block_shape[i] for i in permutation)
  aval_in = ctx.avals_in[0]
  assert isinstance(aval_in, core.ShapedArray)
  assert len(block_shape) == len(aval_in.shape)
  # The last two dimensions may only be permuted with each other.
  if set(permutation[-2:]) != {len(permutation) - 2, len(permutation) - 1}:
    raise NotImplementedError(
        'Cannot permute last two dimensions with leading dimensions.'
    )

  def new_index_map(*args):
    original_idxs = block_spec.index_map(*args)
    return tuple(original_idxs[i] for i in permutation)

  return [pallas_core.BlockSpec(new_shape, new_index_map)]


@register_eval_rule(lax.convert_element_type_p)
def _convert_element_type_eval_rule(
    eval_ctx: KernelEvalContext, x, new_dtype, **params
):
  return jax.lax.convert_element_type(x, new_dtype)


@register_pull_block_spec_rule(lax.convert_element_type_p)
def _convert_element_type_pull_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    new_dtype: jnp.dtype,
    weak_type: bool,
    sharding: jax.sharding.Sharding,
):
  del ctx, new_dtype, weak_type, sharding
  return [block_spec]


@register_eval_rule(lax.iota_p)
def _iota_eval_rule(
    eval_ctx: KernelEvalContext, *, dimension, shape, dtype, sharding
):
  del sharding
  block_spec = eval_ctx.out_block_specs[0]
  block_idx = eval_ctx.get_out_block_indices()[0]
  assert len(block_idx) == len(shape)
  iota_shape = tuple(s for s in block_spec.block_shape if s is not None)
  dim_ = dimension - sum(s is None for s in block_spec.block_shape[:dimension])
  local_iota = jax.lax.broadcasted_iota(dtype, iota_shape, dim_)
  return local_iota + block_idx[dimension] * block_spec.block_shape[dimension]
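
# Illustrative example: a length-512 iota split into blocks of 128 evaluates
# block index 2 as broadcasted_iota over the 128-element block plus 2 * 128,
# i.e. the global iota values 256..383.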


@register_pull_block_spec_rule(lax.iota_p)
def _iota_pull_rule(
    ctx: PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    dtype: jnp.dtype,
    dimension: int,
    shape: tuple[int, ...],
    sharding: jax.sharding.Sharding,
):
  del ctx, sharding, dtype, shape
  if block_spec.block_shape[dimension] is None:
    raise ValueError(
        f'Cannot pull iota along dimension {dimension} with None block size.'
    )
  return []


@register_usage_rule(pjit.pjit_p)
def _jit_usage_rule(
    ctx, used_out: list[set[Usage]], *, jaxpr: core.ClosedJaxpr, **_
):
  read_usage_env = compute_usage(jaxpr.jaxpr, used_out)
  in_usages = util.safe_map(read_usage_env, jaxpr.jaxpr.invars)
  return in_usages


@register_eval_rule(pjit.pjit_p)
def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs):
  jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
  if consts:
    raise NotImplementedError('pjit with consts not supported yet')
  out_tree = tree_util.tree_structure(tuple(jaxpr.outvars))
  in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {}))
  read_usage_env = compute_usage(jaxpr, ctx.out_usages)
  _, env, _ = _pull_block_spec(
      jaxpr,
      ctx.out_block_specs,
      ctx.out_usages,
      scalar_prefetch_handler=ctx.scalar_prefetch_handler,
      grid=ctx.grid,
  )
  kernel_fn = make_kernel_function(
      jaxpr,
      in_tree,
      out_tree,
      read_usage_env,
      ctx.in_block_specs,
      env,
      ctx.scalar_prefetch_handler,
      ctx.grid,
  )
  return kernel_fn(ctx.get_program_ids(), ctx.scalar_prefetch, *args)


@register_pull_block_spec_rule(pjit.pjit_p)
def _jit_pull_block_spec_rule(
    ctx: PullRuleContext, out_block_specs, *, jaxpr, **kwargs
):
  jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
  if consts:
    raise NotImplementedError('pjit with consts not supported yet')
  in_block_specs, _, _ = _pull_block_spec(
      jaxpr,
      out_block_specs,
      ctx.out_usages,
      scalar_prefetch_handler=ctx.scalar_prefetch_handler,
      grid=ctx.grid,
  )
  return in_block_specs


@register_usage_rule(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_usage_rule(
    ctx, used_out: list[set[Usage]], *, call_jaxpr: core.ClosedJaxpr, **_
):
  del ctx
  read_usage_env = compute_usage(call_jaxpr.jaxpr, used_out)
  in_usages = util.safe_map(read_usage_env, call_jaxpr.jaxpr.invars)
  return in_usages


@register_eval_rule(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_eval_rule(
    ctx: KernelEvalContext, *args, call_jaxpr: core.ClosedJaxpr, **kwargs
):
  jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  if consts:
    raise NotImplementedError('custom_jvp_call with consts not supported yet')
  out_tree = tree_util.tree_structure(tuple(jaxpr.outvars))
  in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {}))
  read_usage_env = compute_usage(jaxpr, ctx.out_usages)
  _, env, _ = _pull_block_spec(
      jaxpr,
      ctx.out_block_specs,
      ctx.out_usages,
      scalar_prefetch_handler=ctx.scalar_prefetch_handler,
      grid=ctx.grid,
  )
  kernel_fn = make_kernel_function(
      jaxpr,
      in_tree,
      out_tree,
      read_usage_env,
      ctx.in_block_specs,
      env,
      ctx.scalar_prefetch_handler,
      ctx.grid,
  )
  return kernel_fn(ctx.get_program_ids(), ctx.scalar_prefetch, *args)


@register_pull_block_spec_rule(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_pull_block_spec_rule(
    ctx: PullRuleContext, out_block_specs, *, call_jaxpr, **kwargs
):
  jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  if consts:
    raise NotImplementedError('custom_jvp_call with consts not supported yet')
  in_block_specs, _, _ = _pull_block_spec(
      jaxpr,
      out_block_specs,
      ctx.out_usages,
      scalar_prefetch_handler=ctx.scalar_prefetch_handler,
      grid=ctx.grid,
  )
  return in_block_specs


def push_block_spec(
    f: Callable,
    *in_spec_args,
    **in_spec_kwargs,
):
  def wrapper(*args, **kwargs):
    flat_block_specs, in_tree_ = tree_util.tree_flatten(
        (in_spec_args, in_spec_kwargs)
    )
    jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
    if in_tree != in_tree_:
      raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}')
    out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs)
    return tree_util.tree_unflatten(out_tree, out_bs)

  return wrapper
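
# Illustrative usage sketch (hypothetical shapes): push_block_spec is the
# forward analogue of pull_block_spec, mapping input block specs to output
# block specs:
#
#   in_spec = pallas_core.BlockSpec((128, 128), lambda i, j: (i, j))
#   out_spec = push_block_spec(lambda x: x.astype(jnp.float32), in_spec)(x)
#
# Here `out_spec` equals `in_spec`, since convert_element_type's push rule
# forwards the block spec unchanged.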


def _push_block_spec_jaxpr(
    jaxpr: core.Jaxpr,
    *flat_block_specs,
) -> tuple[pallas_core.BlockSpec, ...]:
  num_inputs = len(jaxpr.invars)
  if len(flat_block_specs) != num_inputs:
    raise ValueError(
        f'Expected {num_inputs} block specs, got {len(flat_block_specs)}'
    )

  env: dict[core.Var, pallas_core.BlockSpec | pallas_core.NoBlockSpec] = {}

  for invar, bs in zip(jaxpr.invars, flat_block_specs, strict=True):
    env[invar] = bs
  for constvar in jaxpr.constvars:
    env[constvar] = pallas_core.no_block_spec

  def _read_block_spec(
      atom: core.Atom,
  ) -> pallas_core.BlockSpec | pallas_core.NoBlockSpec:
    if isinstance(atom, core.Literal):
      return pallas_core.no_block_spec
    return env[atom]

  def _write_block_spec(
      atom: core.Atom,
      block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec,
  ):
    if isinstance(atom, core.Literal):
      return
    env[atom] = block_spec

  for eqn in jaxpr.eqns:
    in_block_specs = tuple(util.safe_map(_read_block_spec, eqn.invars))
    if all(bs is pallas_core.no_block_spec for bs in in_block_specs):
      for outvar in eqn.outvars:
        _write_block_spec(outvar, pallas_core.no_block_spec)
      continue
    rule = push_block_spec_rules.get(eqn.primitive, None)
    if not rule:
      raise NotImplementedError(eqn.primitive)
    ctx = PushRuleContext(
        avals_in=tuple(v.aval for v in eqn.invars),
        avals_out=tuple(v.aval for v in eqn.outvars),
    )
    if eqn.primitive.multiple_results:
      out_block_specs = rule(ctx, *in_block_specs, **eqn.params)
    else:
      out_block_specs = [rule(ctx, *in_block_specs, **eqn.params)]

    util.safe_map(_write_block_spec, eqn.outvars, out_block_specs)
  out_block_specs = tuple(util.safe_map(_read_block_spec, jaxpr.outvars))
  valid_block_spec = [
      bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec
  ][0]
  out_block_specs = tuple(
      valid_block_spec if obs is pallas_core.no_block_spec else obs
      for obs in out_block_specs
  )
  if any(bs is pallas_core.no_block_spec for bs in out_block_specs):
    raise ValueError('No block spec found for output')
  return out_block_specs  # pytype: disable=bad-return-type


push_block_spec_rules: dict[core.Primitive, PushBlockSpecRuleFn] = {}


class PushBlockSpecRuleFn(Protocol):

  def __call__(
      self,
      ctx: PushRuleContext,
      block_spec: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
      **params: Any,
  ) -> pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...]:
    ...


def register_push_block_spec_rule(
    prim: core.Primitive,
) -> Callable[[PushBlockSpecRuleFn], PushBlockSpecRuleFn]:

  def wrapper(
      f: PushBlockSpecRuleFn,
  ) -> PushBlockSpecRuleFn:
    push_block_spec_rules[prim] = f
    return f

  return wrapper


def _binop_push_rule(
    prim: core.Primitive,
    ctx: PushRuleContext,
    left_block_spec: pallas_core.BlockSpec,
    right_block_spec: pallas_core.BlockSpec,
    **params: Any,
) -> pallas_core.BlockSpec:
  del prim, params
  left_aval, right_aval = ctx.avals_in
  assert isinstance(left_aval, core.ShapedArray)
  assert isinstance(right_aval, core.ShapedArray)
  if left_block_spec is pallas_core.no_block_spec:
    return right_block_spec
  if right_block_spec is pallas_core.no_block_spec:
    return left_block_spec
  if right_block_spec != left_block_spec:
    raise ValueError('Invalid block spec')
  return left_block_spec


register_binop_push_rule = lambda prim: register_push_block_spec_rule(prim)(
    functools.partial(_binop_push_rule, prim)
)
register_binop_push_rule(lax.mul_p)
register_binop_push_rule(lax.add_p)
register_binop_push_rule(lax.sub_p)
register_binop_push_rule(lax.div_p)
register_binop_push_rule(lax.max_p)
register_binop_push_rule(lax.lt_p)
register_binop_push_rule(lax.eq_p)
register_binop_push_rule(lax.gt_p)
register_binop_push_rule(lax.and_p)
register_binop_push_rule(ad_util.add_any_p)


def _eltwise_push_rule(
    prim: core.Primitive,
    ctx: PushRuleContext,
    block_spec: pallas_core.BlockSpec,
    **params,
) -> pallas_core.BlockSpec:
  del prim, ctx, params
  return block_spec


@register_push_block_spec_rule(lax.transpose_p)
def _transpose_push_rule(
    ctx: PushRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    permutation: tuple[int, ...],
) -> pallas_core.BlockSpec:
  del ctx
  block_shape = block_spec.block_shape
  new_shape = [block_shape[i] for i in permutation]
  # The last two dimensions may only be permuted with each other.
  if set(permutation[-2:]) != {len(permutation) - 2, len(permutation) - 1}:
    raise NotImplementedError(
        'Cannot permute last two dimensions with leading dimensions.'
    )

  def new_index_map(*args):
    original_idxs = block_spec.index_map(*args)
    return tuple(original_idxs[i] for i in permutation)

  return pallas_core.BlockSpec(new_shape, new_index_map)


@register_push_block_spec_rule(lax.convert_element_type_p)
def _convert_element_type_push_rule(
    ctx: PushRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    new_dtype: jnp.dtype,
    weak_type: bool,
    sharding: jax.sharding.Sharding,
):
  del ctx, new_dtype, weak_type, sharding
  return block_spec


@register_push_block_spec_rule(lax.select_n_p)
def _select_n_push_rule(
    ctx: PushRuleContext,
    *args: pallas_core.BlockSpec,
):
  del ctx
  block_specs = [b for b in args if b is not pallas_core.no_block_spec]
  if len(block_specs) > 1:
    raise NotImplementedError('select_n with multiple inputs not supported yet')
  return block_specs[0]


@register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_push_rule(
    ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_
):
  assert not call_jaxpr.consts
  return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs)


@register_push_block_spec_rule(pjit.pjit_p)
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):
  assert not jaxpr.consts
  return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs)


def register_eltwise_rule(prim: core.Primitive):
  register_pull_block_spec_rule(prim)(
      functools.partial(_eltwise_pull_rule, prim)
  )
  register_usage_rule(prim)(functools.partial(_eltwise_usage_rule, prim))
  register_eval_rule(prim)(functools.partial(_eltwise_eval_rule, prim))
  register_push_block_spec_rule(prim)(
      functools.partial(_eltwise_push_rule, prim)
  )


register_eltwise_rule(lax.exp_p)
register_eltwise_rule(lax.tanh_p)
register_eltwise_rule(lax.sin_p)
register_eltwise_rule(lax.cos_p)
register_eltwise_rule(lax.sqrt_p)
register_eltwise_rule(lax.rsqrt_p)
register_eltwise_rule(lax.log_p)
register_eltwise_rule(lax.integer_pow_p)