[Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels

PiperOrigin-RevId: 733112191
Sharad Vikram 2025-03-03 17:05:09 -08:00 committed by jax authors
parent 2c7043f63d
commit 0b6c355083
10 changed files with 4281 additions and 0 deletions
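For orientation, a minimal usage sketch of the new API (not part of the diff; the kernel body and shapes are hypothetical, while `fusable` and `fuse` are the re-exports added in jax/_src/pallas/fuser/__init__.py below):

import jax.numpy as jnp
from jax._src.pallas.fuser import fusable, fuse

@fusable
def my_op(x_fusion, out_fusion):
  # Inputs arrive as Fusion objects: callables (with in/out types attached)
  # that compute the possibly fused-in input values.
  y = x_fusion() * 2.0  # stand-in for a real Pallas kernel body
  # out_fusion is None when the fusable is invoked outside of `fuse`.
  return out_fusion(y) if out_fusion is not None else y

@fuse
def f(x):
  return my_op(jnp.exp(x)) + 1.0  # exp and +1.0 get fused into my_op

f(jnp.ones((512, 512), jnp.float32))  # == jnp.exp(x) * 2.0 + 1.0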

jax/_src/pallas/fuser/BUILD

@@ -0,0 +1,121 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_strict_library",
)
package(
default_applicable_licenses = [],
default_visibility = [
"//jax:internal",
],
)
pytype_strict_library(
name = "fuser",
srcs = [
"__init__.py",
],
deps = [
":block_spec",
":fusable",
":fusion",
":jaxpr_fusion",
],
)
pytype_strict_library(
name = "block_spec",
srcs = [
"block_spec.py",
],
deps = [
"//jax",
"//jax:ad_util",
"//jax:api_util",
"//jax:core",
"//jax:partial_eval",
"//jax:tree_util",
"//jax:util",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
pytype_strict_library(
name = "fusable",
srcs = [
"fusable.py",
],
deps = [
":fusion",
"//jax",
"//jax:api_util",
"//jax:core",
"//jax:mlir",
"//jax:partial_eval",
"//jax:tree_util",
"//jax:util",
],
)
pytype_strict_library(
name = "fusion",
srcs = [
"fusion.py",
],
deps = [
"//jax",
"//jax:util",
],
)
pytype_strict_library(
name = "jaxpr_fusion",
srcs = [
"jaxpr_fusion.py",
],
deps = [
":fusable",
":fusable_dtype",
":fusion",
"//jax",
"//jax:api_util",
"//jax:core",
"//jax:partial_eval",
"//jax:tree_util",
],
)
pytype_strict_library(
name = "fusable_dtype",
srcs = [
"fusable_dtype.py",
],
deps = [
":block_spec",
":fusable",
"//jax",
"//jax:api_util",
"//jax:core",
"//jax:dtypes",
"//jax:partial_eval",
"//jax:source_info_util",
"//jax:tree_util",
"//jax:util",
"//jax/_src/pallas",
],
)

jax/_src/pallas/fuser/__init__.py

@@ -0,0 +1,21 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

jax/_src/pallas/fuser/block_spec.py (file diff suppressed because it is too large)

jax/_src/pallas/fuser/fusable.py

@@ -0,0 +1,83 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusable primitive."""
import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusion as fusion_lib
fusable_p = jax_core.Primitive('fusable')
fusable_p.multiple_results = True
def _get_aval(x):
return jax_core.raise_to_shaped(jax_core.get_aval(x))
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
return fusion_lib.Fusion(
func=lambda: x,
in_type=((), {}),
out_type=jax.ShapeDtypeStruct(x.shape, x.dtype),
)
def fusable(f):
def wrapper(*args):
def wrapped(*args):
in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
return f(*in_fusions, None)
flat_args, in_tree = tree_util.tree_flatten(args)
debug_info = api_util.debug_info('fusable', wrapped, args, {})
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
)
flat_avals = [_get_aval(x) for x in flat_args]
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
out = fusable_p.bind(
*consts,
*flat_args,
jaxpr=jaxpr,
num_consts=len(consts),
in_tree=in_tree,
out_tree=out_tree,
func=f,
)
return tree_util.tree_unflatten(out_tree, out)
return wrapper
@fusable_p.def_impl
def _(*consts_and_args, jaxpr, num_consts, **_):
consts, args = util.split_list(consts_and_args, [num_consts])
return jax_core.eval_jaxpr(jaxpr, consts, *args)
mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl))
@fusable_p.def_abstract_eval
def _(*args, jaxpr, **kwargs):
del args, kwargs
return [v.aval for v in jaxpr.outvars]
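A usage note (illustrative, not part of the diff): when a `fusable`-wrapped function is called directly, i.e. outside of `fuse`, each input is wrapped in a trivial Fusion via `_make_trivial_fusion` and no output fusion is supplied, so evaluation falls through to the `fusable_p.def_impl` rule above:

import jax.numpy as jnp

@fusable
def scale(x_fusion, out_fusion):
  assert out_fusion is None  # direct call: no output fusion is passed
  return x_fusion() * 2.0    # x_fusion is a trivial fusion returning x

scale(jnp.ones((8,)))  # evaluated eagerly through fusable_p's impl rule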

jax/_src/pallas/fuser/fusable_dtype.py

@@ -0,0 +1,465 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom fusable dtypes."""
import abc
import dataclasses
import functools
from typing import Any, Sequence, TypeVar
import jax
from jax._src import api_util
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow import conditionals
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas.fuser import block_spec
from jax._src.pallas.fuser.fusable import fusable_p
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
T = TypeVar("T")
_physicalize_rules = {}
pack_dtype_p = core.Primitive("pack_dtype")
@pack_dtype_p.def_abstract_eval
def pack_dtype_abstract_eval(*xs, dtype):
if dtypes.issubdtype(dtype, FusableElementDType):
return dtype.abstract_pack(*xs)
raise ValueError("Attempted to pack non-fusion dtype: {dtype}")
def pack(*xs, dtype):
return pack_dtype_p.bind(*xs, dtype=dtype)
unpack_dtype_p = core.Primitive("unpack_dtype")
unpack_dtype_p.multiple_results = True
@unpack_dtype_p.def_abstract_eval
def unpack_dtype_abstract_eval(x):
if dtypes.issubdtype(x.dtype, FusableElementDType):
return x.dtype.abstract_unpack(x)
elif isinstance(x.dtype, pallas_core.AbstractMemoryRef):
raise NotImplementedError()
raise ValueError("Attempted to unpack non-fusion dtype: {dtype}")
def unpack(x):
return unpack_dtype_p.bind(x)
class FusableElementDType(dtypes.extended):
"""Scalar dtype for fusable dtypes."""
pass
class FusableTyRules:
allow_conversion: bool = False
class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
"""Base class for fusable extended dtypes."""
_op_registry = {}
_rules = FusableTyRules
type = FusableElementDType
@abc.abstractmethod
def abstract_unpack(self, x) -> Sequence[Any]:
raise NotImplementedError()
@abc.abstractmethod
def abstract_pack(self, *xs):
raise NotImplementedError()
@classmethod
def register_op(cls, primitive):
def _register_fn(fn):
cls._op_registry[primitive] = fn
return _register_fn
@classmethod
def get_op_rule(cls, primitive):
return cls._op_registry.get(primitive)
@property
def name(self):
return str(self)
@abc.abstractmethod
def pull_block_spec_one_step(self, *args, **kwargs):
raise NotImplementedError()
def physicalize(f):
"""Runs a function that contains fusable extended dtypes."""
def wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError()
flattened_args, treedef = jax.tree.flatten(args)
debug_info = api_util.debug_info("physicalize", f, args, kwargs)
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(f, debug_info=debug_info), treedef
)
avals = [core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
out_flat = core.eval_jaxpr(
new_jaxpr.jaxpr, new_jaxpr.consts, *flattened_args
)
return tree_util.tree_unflatten(out_tree_thunk(), out_flat)
return wrapper
@util.weakref_lru_cache
def physicalize_closed_jaxpr(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
"""Replaces all extended dtypes with physical types in a jaxpr."""
fun = functools.partial(physicalize_interp, jaxpr.jaxpr, jaxpr.consts)
in_avals = [_physical_aval(aval) for aval in jaxpr.in_avals]
flat_avals, treedef = tree_util.tree_flatten(in_avals)
debug_info = api_util.debug_info("physicalize_closed_jaxpr", fun, (), {})
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info), treedef
)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
assert len(new_jaxpr.constvars) == len(consts), "Mismatched consts"
return core.ClosedJaxpr(new_jaxpr, consts)
def _physical_aval(aval):
if isinstance(aval, core.ShapedArray):
return core.ShapedArray(aval.shape, aval.dtype)
if isinstance(aval, state.AbstractRef):
if isinstance(aval.dtype, FusionDType):
unpacked = aval.dtype.abstract_unpack(aval.inner_aval)
return tuple(aval.update(inner_aval=u) for u in unpacked)
return aval
return aval
def physicalize_jaxpr(jaxpr: core.Jaxpr) -> core.Jaxpr:
"""Replaces all extended dtypes with physical types in a jaxpr."""
def _flat_jaxpr_eval(consts, args):
return physicalize_interp(jaxpr, consts, *args)
in_avals = [_physical_aval(v.aval) for v in jaxpr.invars]
const_avals = [_physical_aval(v.aval) for v in jaxpr.constvars]
flat_avals, treedef = jax.tree.flatten((const_avals, in_avals))
debug_info = api_util.debug_info(
"physicalize_jaxpr", _flat_jaxpr_eval, (), {}
)
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(_flat_jaxpr_eval, debug_info=debug_info), treedef
)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
assert not consts
new_jaxpr = pe.convert_invars_to_constvars(
new_jaxpr, len(tree_util.tree_leaves(const_avals))
)
return new_jaxpr
@dataclasses.dataclass
class Context:
avals_in: Sequence[Any]
avals_out: Sequence[Any]
def physicalize_interp(
jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
):
"""Physicalizes a jaxpr by replacing fusable dtypes with physical types."""
# TODO: Merge into JAX core.
env: dict[core.Var, Any] = {}
def read_env(var: core.Atom):
if isinstance(var, core.Literal):
return var.val
return env[var]
def write_env(var: core.Var, val: Any):
env[var] = val
map(write_env, jaxpr.constvars, consts)
assert len(jaxpr.invars) == len(
args
), f"Length mismatch: {jaxpr.invars} != {args}"
map(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = list(map(read_env, eqn.invars))
avals_in = tuple(x.aval for x in eqn.invars)
name_stack = (
source_info_util.current_name_stack() + eqn.source_info.name_stack
)
with (
source_info_util.user_context(
eqn.source_info.traceback, name_stack=name_stack
),
eqn.ctx.manager,
):
# Check the input types, then invoke the matching physicalization rule.
in_types = [aval.dtype for aval in avals_in] # pytype: disable=attribute-error
ctx = Context(
avals_in=avals_in, avals_out=[var.aval for var in eqn.outvars]
)
custom_rule = _phys_find_rule(eqn.primitive, in_types)
if custom_rule:
outvals = custom_rule(ctx, *invals, **eqn.params)
else:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
if eqn.primitive.multiple_results:
assert len(outvals) == len(eqn.outvars)
map(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
return map(read_env, jaxpr.outvars)
def _phys_find_rule(primitive, types: Sequence[dtypes.DType]):
"""Finds the physicalization rule for a primitive."""
if primitive in _physicalize_rules:
return _physicalize_rules[primitive]
fusion_types = {type_ for type_ in types if isinstance(type_, FusionDType)}
if len(fusion_types) == 0:
return None
elif len(fusion_types) > 1:
raise ValueError(f"Multiple fusion types for primitive: {fusion_types}")
fusion_type = fusion_types.pop()
if primitive not in fusion_type._op_registry:
raise ValueError(
f"No implementation found for primitive {primitive} "
f"for custom type {fusion_type}"
)
return fusion_type.get_op_rule(primitive)
def _assert_no_fusion_types(avals: Sequence[core.AbstractValue]):
for aval in avals:
if isinstance(aval, (core.ShapedArray, state.AbstractRef)):
if isinstance(aval.dtype, FusionDType):
raise NotImplementedError(f"Fusion type found in avals: {avals}")
def _pallas_call_physicalize_rule(
ctx: Context, *args, jaxpr, grid_mapping: pallas_core.GridMapping, **kwargs
):
_assert_no_fusion_types(ctx.avals_in)
_assert_no_fusion_types(ctx.avals_out)
with grid_mapping.trace_env():
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, ()))
num_new_vals = len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars)
grid_mapping = grid_mapping.replace(
num_scratch_operands=grid_mapping.num_scratch_operands + num_new_vals
)
return pallas_call.pallas_call_p.bind(
*args, jaxpr=new_jaxpr.jaxpr, grid_mapping=grid_mapping, **kwargs
)
_physicalize_rules[pallas_call.pallas_call_p] = _pallas_call_physicalize_rule
def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs):
_assert_no_fusion_types(ctx.avals_out)
physicalized_branches = [
physicalize_closed_jaxpr(branch) for branch in branches
]
flat_args = jax.tree.leaves(args)
return conditionals.cond_p.bind(
*flat_args, branches=physicalized_branches, **kwargs
)
_physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule
def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized):
_assert_no_fusion_types(ctx.avals_in)
_assert_no_fusion_types(ctx.avals_out)
jaxpr = physicalize_jaxpr(jaxpr)
return state_discharge.run_state_p.bind(
*args,
jaxpr=jaxpr,
which_linear=which_linear,
is_initialized=is_initialized,
)
_physicalize_rules[state_discharge.run_state_p] = _run_state_rule
def _core_map_rule(ctx: Context, *args, jaxpr, **params):
_assert_no_fusion_types(ctx.avals_in)
_assert_no_fusion_types(ctx.avals_out)
assert not jaxpr.invars
with core.extend_axis_env_nd(params["mesh"].shape.items()):
jaxpr = physicalize_jaxpr(jaxpr)
return pallas_core.core_map_p.bind(*args, jaxpr=jaxpr, **params)
_physicalize_rules[pallas_core.core_map_p] = _core_map_rule
def _run_scoped_rule(ctx: Context, *args, jaxpr, **params):
_assert_no_fusion_types(ctx.avals_out)
jaxpr = physicalize_jaxpr(jaxpr)
flat_args = tree_util.tree_leaves(args)
assert len(flat_args) == len(
jaxpr.constvars
), f"Length mismatch: {len(flat_args)=} != {len(jaxpr.constvars)=}"
return pallas_primitives.run_scoped_p.bind(*flat_args, jaxpr=jaxpr, **params)
_physicalize_rules[pallas_primitives.run_scoped_p] = _run_scoped_rule
def _scan_rule(ctx: Context, *args, jaxpr, **params):
_assert_no_fusion_types(ctx.avals_in)
_assert_no_fusion_types(ctx.avals_out)
jaxpr = physicalize_closed_jaxpr(jaxpr)
return jax.lax.scan_p.bind(*args, jaxpr=jaxpr, **params)
_physicalize_rules[jax.lax.scan_p] = _scan_rule
def _while_rule(
ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts, **params
):
_assert_no_fusion_types(ctx.avals_out)
cond_avals = [v.aval for v in cond_jaxpr.jaxpr.invars]
_assert_no_fusion_types(cond_avals)
body_avals = [v.aval for v in body_jaxpr.jaxpr.invars]
_, body_in_avals = util.split_list(body_avals, [body_nconsts])
_assert_no_fusion_types(body_in_avals)
new_body_jaxpr = physicalize_closed_jaxpr(body_jaxpr)
new_num_body_consts = (
body_nconsts
+ len(new_body_jaxpr.jaxpr.invars)
- len(body_jaxpr.jaxpr.invars)
)
flat_args = tree_util.tree_leaves(args)
assert len(flat_args) == len(new_body_jaxpr.jaxpr.invars), (
f"Length mismatch: {len(flat_args)=} !="
f" {len(new_body_jaxpr.jaxpr.invars)=}"
)
return jax.lax.while_p.bind(
*flat_args,
body_jaxpr=new_body_jaxpr,
cond_jaxpr=cond_jaxpr,
body_nconsts=new_num_body_consts,
**params,
)
_physicalize_rules[jax.lax.while_p] = _while_rule
def _pack_rule(_, *args, dtype):
del dtype
return args
_physicalize_rules[pack_dtype_p] = _pack_rule
def _unpack_rule(_, arg):
return arg
_physicalize_rules[unpack_dtype_p] = _unpack_rule
def _swap_rule(ctx: Context, ref, val, *args, tree):
ref_aval, *_ = ctx.avals_in
if not isinstance(ref_aval.dtype, FusionDType):
return state_primitives.swap_p.bind(ref, val, *args, tree=tree)
return ref_aval.dtype.swap(ref, val, *args, tree=tree)
_physicalize_rules[state_primitives.swap_p] = _swap_rule
def _get_rule(ctx: Context, ref, *args, tree):
ref_aval, *_ = ctx.avals_in
if not isinstance(ref_aval.dtype, FusionDType):
return state_primitives.get_p.bind(ref, *args, tree=tree)
return ref_aval.dtype.get(ref, *args, tree=tree)
_physicalize_rules[state_primitives.get_p] = _get_rule
@block_spec.register_eval_rule(pack_dtype_p)
def _pack_dtype_eval_rule(eval_ctx: block_spec.KernelEvalContext, *args, dtype):
del eval_ctx
return pack_dtype_p.bind(*args, dtype=dtype)
@block_spec.register_pull_block_spec_rule(pack_dtype_p)
def _pack_dtype_pull_rule(
ctx: block_spec.PullRuleContext,
block_spec: pallas_core.BlockSpec,
*,
dtype: FusionDType,
):
del ctx
return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error
def _fusable_physicalize_rule(
_, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
):
consts, _ = util.split_list(consts_and_args, [num_consts])
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
return fusable_p.bind(
*consts_and_args,
jaxpr=new_jaxpr.jaxpr,
num_consts=num_consts,
in_tree=in_tree,
out_tree=out_tree,
func=func,
)
_physicalize_rules[fusable_p] = _fusable_physicalize_rule

jax/_src/pallas/fuser/fusion.py

@@ -0,0 +1,59 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusion classes."""
from __future__ import annotations
import dataclasses
from typing import Any, Callable, Generic, ParamSpec, TypeVar
import jax
from jax._src import util
safe_map = util.safe_map
A = ParamSpec("A")
K = TypeVar("K")
@dataclasses.dataclass
class Fusion(Generic[A, K]):
func: Callable[A, K]
in_type: tuple[tuple[Any, ...], dict[str, Any]]
out_type: Any
def __call__(self, *args: A.args, **kwargs: A.kwargs):
return self.func(*args, **kwargs)
@property
def shape(self):
return jax.tree.map(lambda x: x.shape, self.out_type)
@property
def dtype(self):
return jax.tree.map(lambda x: x.dtype, self.out_type)
@property
def type(self):
return self.out_type
@property
def in_shape(self):
return jax.tree.map(lambda x: x.shape, self.in_type)
@property
def in_dtype(self):
return jax.tree.map(lambda x: x.dtype, self.in_type)

jax/_src/pallas/fuser/jaxpr_fusion.py

@@ -0,0 +1,147 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fuses a function."""
from typing import Any
import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusable_dtype
from jax._src.pallas.fuser import fusion as fusion_lib
from jax._src.pallas.fuser.fusable import fusable_p
def _get_aval(x):
return jax_core.raise_to_shaped(jax_core.get_aval(x))
def fuse(f, *, physicalize: bool = False):
"""Fuses a function into a single fusable.
There should be a single call to a `fusable` inside the body of `f`. `fuse`
returns a transformed function that will fuse the surrounding computation into
the fusable and invoke it.
"""
def wrapper(*args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
debug_info = api_util.debug_info('fuse', f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
flat_avals = [_get_aval(x) for x in flat_args]
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
return tree_util.tree_unflatten(out_tree, out_flat)
if physicalize:
wrapper = fusable_dtype.physicalize(wrapper)
return wrapper
_fusable: dict[jax_core.Primitive, Any] = {}
def construct_fusion(
candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
) -> fusion_lib.Fusion:
flat_outvars, out_tree = tree_util.tree_flatten(outvars)
flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs))
new_jaxpr_no_dce = jaxpr.replace(
outvars=flat_outvars,
constvars=jaxpr.constvars + jaxpr.invars,
invars=flat_invars,
)
new_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
new_jaxpr_no_dce,
[True] * len(new_jaxpr_no_dce.outvars),
instantiate=[False] * len(new_jaxpr_no_dce.constvars)
+ [True] * len(new_jaxpr_no_dce.invars),
)
assert all(used_invars), new_jaxpr_no_dce
new_values = tuple(
c for used, c in zip(used_consts, candidate_values, strict=True) if used
)
kernel_in_tree = tree_util.tree_structure((invars, kwargs))
def _fn(*args, **kwargs):
flat_args, _ = tree_util.tree_flatten((args, kwargs))
out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args)
return tree_util.tree_unflatten(out_tree, out_flat)
flat_in_type = [
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars
]
in_type = tree_util.tree_unflatten(kernel_in_tree, flat_in_type)
out_type = tree_util.tree_unflatten(
out_tree,
[jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars],
)
return fusion_lib.Fusion(_fn, in_type, out_type)
def fuse_jaxpr(
jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args
):
fusion_eqn_index = None
# Find the fusable equation in the jaxpr.
for i, eqn in enumerate(jaxpr.eqns):
if eqn.primitive is fusable_p:
fusion_eqn_index = i
break
if fusion_eqn_index is None:
raise ValueError("No fusable eqn found")
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
candidate_values = [*consts, *args]
# Construct fusions for non-constant inputs to the fusable.
in_fusions_flat = [
construct_fusion(
candidate_values,
jaxpr.replace(
eqns=jaxpr.eqns[:fusion_eqn_index],
),
var,
)
for var in fusion_eqn.invars[fusion_eqn.params["num_consts"] :]
]
in_fusions = tree_util.tree_unflatten(
fusion_eqn.params["in_tree"], in_fusions_flat
)
out_fusion = construct_fusion(
candidate_values,
jaxpr.replace(
eqns=jaxpr.eqns[:fusion_eqn_index]
+ jaxpr.eqns[fusion_eqn_index + 1 :]
),
tree_util.tree_unflatten(out_tree, jaxpr.outvars),
tree_util.tree_unflatten(
fusion_eqn.params["out_tree"], fusion_eqn.outvars
),
)
# Run the fusable.
out = fusion_eqn.params["func"](*in_fusions, out_fusion)
# Now return the flattened output (the fuse_jaxpr caller should unflatten).
out_flat = tree_util.tree_leaves(out)
assert len(out_flat) == len(jaxpr.outvars)
return out_flat
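Concretely (an illustrative decomposition, not code from this diff): for `f(x) = g(jnp.exp(x)) + 1.0` with `g` a fusable, `fuse_jaxpr` splits the traced jaxpr at the fusable equation into

in_fusion  = lambda: jnp.exp(x)   # prefix eqns; x is captured as a value
out_fusion = lambda y: y + 1.0    # suffix eqns applied to g's output

and then calls `g`'s underlying function as `func(in_fusion, out_fusion)`, leaving the fusable to decide where each piece is evaluated (e.g. inside a Pallas kernel).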

tests/pallas/BUILD

@@ -589,3 +589,60 @@ jax_multiplatform_test
"//jax:pallas_mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
name = "fuser_block_spec_test",
srcs = [
"fuser_block_spec_test.py",
],
disable_configs = [
"cpu",
"cpu_shardy",
],
enable_backends = ["cpu"],
tags = [
"noasan",
"nomsan",
"notsan",
],
deps = [
"//jax:pallas",
"//jax/_src/pallas/fuser",
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
name = "tpu_fusable_matmul_test",
srcs = ["tpu_fusable_matmul_test.py"],
disable_configs = [
"tpu_v3_1x1",
"tpu_pjrt_c_api",
"gpu_v100",
"gpu_v100_x32",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_h100",
"cpu",
"cpu_x32",
"cpu_shardy",
],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v4_1x1",
"tpu_v5e",
"tpu_v5p_1x1",
"tpu_v6e_1x1",
],
shard_count = 4,
tags = [
"noasan",
"nomsan",
"notsan",
],
deps = [
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas/fuser",
] + py_deps("absl/testing") + py_deps("numpy"),
)

tests/pallas/fuser_block_spec_test.py

@@ -0,0 +1,776 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pull block spec."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.fuser import block_spec as block_spec_lib
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
class PullBlockSpecTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if config.enable_x64.value:
self.skipTest('x64 not supported')
def test_identity(self):
def f(x):
return x
in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, in_type
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs, in_block_spec), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values, in_type)
)
# We should have no new values or scalar prefetch values.
self.assertEmpty(scalar_prefetch_values)
self.assertEmpty(value_block_specs)
self.assertEqual(in_block_spec.block_shape, (128, 128))
self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))
x = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, new_values, x),
x,
)
def test_const(self):
x = np.ones((512, 512), dtype=np.float32)
def f():
return x
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
self.assertEmpty(scalar_prefetch_values)
self.assertEqual(value_block_specs[0].block_shape, (128, 128))
self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (0, 1))
x_block = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn(
(0, 0, 0),
scalar_prefetch_values,
(np.ones((128, 128), dtype=np.float32),),
),
x_block,
)
@parameterized.parameters([jnp.exp, jnp.tanh])
def test_elementwise(self, fn):
def f(x):
return fn(x)
in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, in_type
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs, in_block_spec), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
0
),
)(new_values, in_type)
)
self.assertEmpty(value_block_specs)
self.assertEqual(in_block_spec.block_shape, (128, 128))
self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))
x = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (), x),
fn(x),
)
@parameterized.parameters([jnp.exp, jnp.tanh])
def test_elementwise_bias(self, fn):
b = np.ones((512, 512), dtype=np.float32)
def f(x):
return fn(x) + b
in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, in_type
)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs, in_block_spec), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values, in_type)
)
self.assertLen(value_block_specs, 1)
self.assertEqual(value_block_specs[0].block_shape, (128, 128))
self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (0, 1))
self.assertEqual(in_block_spec.block_shape, (128, 128))
self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))
x = np.ones((128, 128), dtype=np.float32)
b = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (b,), x),
fn(x) + b,
)
@parameterized.product(
fn=[lax.mul, lax.add, lax.sub, lax.div, lax.max, lax.lt, lax.eq, lax.gt],
)
def test_binop(self, fn):
def f(x, y):
return fn(x, y)
in_type = (
jax.ShapeDtypeStruct((512, 512), jnp.float32),
jax.ShapeDtypeStruct((512, 512), jnp.float32),
)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, *in_type
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs, *in_block_specs), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values, *in_type)
)
self.assertEmpty(value_block_specs)
self.assertLen(in_block_specs, 2)
x_block_spec, y_block_spec = in_block_specs
self.assertEqual(x_block_spec.block_shape, (128, 128))
self.assertEqual(
x_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2)
)
self.assertEqual(y_block_spec.block_shape, (128, 128))
self.assertEqual(
y_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2)
)
x = np.ones((128, 128), dtype=np.float32)
y = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, new_values, x, y),
fn(x, y),
)
def test_slice(self):
x = jax.random.normal(jax.random.key(0), (4, 512, 512), dtype=np.float32)
def f():
return jax.lax.slice(x, (1, 0, 0), (2, 512, 512))
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((1, 128, 128), lambda i, j, k: (0, i, j))
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(0),
)(new_values)
self.assertLen(value_block_specs, 1)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (1, 128, 128))
self.assertEqual(x_block_spec.index_map(4, 2, 3), (1, 4, 2))
x = np.ones((1, 128, 128), dtype=np.float32)
# Slicing doesn't change the block's value after pulling the block spec.
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
)
def test_squeeze(self):
x = jax.random.normal(jax.random.key(0), (1, 512, 512), dtype=np.float32)
def f():
return jnp.squeeze(x, axis=0)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (None, 128, 128))
self.assertEqual(x_block_spec.index_map(4, 2, 3), (0, 4, 2))
x = np.ones((128, 128), dtype=np.float32)
# Squeeze doesn't change the block's value after pulling the block spec,
# and the squeezed dimensions are removed from the block shape.
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
)
def test_dynamic_slice_only(self):
x = jax.random.normal(jax.random.key(0), (3, 4, 512, 512), dtype=np.float32)
i = jnp.array(1, dtype=jnp.int32)
j = jnp.array(2, dtype=jnp.int32)
def f():
return jax.lax.dynamic_slice(x, (i, j, 0, 0), (1, 1, 512, 512))
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertLen(scalar_prefetch_values, 2)
block_spec = pl.BlockSpec(
(1, 1, 128, 128), lambda i, j, k, *_: (0, 0, i, j)
)
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
scalar_prefetch_values = jax.tree.map(
lambda x: x[None], scalar_prefetch_values
)
self.assertLen(value_block_specs, 1)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (1, 1, 128, 128))
self.assertEqual(
x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 2, 0, 1)
)
x = np.ones((1, 1, 128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
)
def test_dynamic_slice_squeeze(self):
x = jax.random.normal(jax.random.key(0), (3, 4, 512, 512), dtype=np.float32)
i = jnp.array(1, dtype=jnp.int32)
j = jnp.array(2, dtype=jnp.int32)
def f():
return jnp.squeeze(
jax.lax.dynamic_slice(x, (i, j, 0, 0), (1, 1, 512, 512)), axis=(0, 1)
)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertLen(scalar_prefetch_values, 2)
scalar_prefetch_values = jax.tree.map(
lambda x: x[None], scalar_prefetch_values
)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k, *_: (i, j))
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (None, None, 128, 128))
self.assertEqual(
x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 2, 0, 1)
)
x = np.ones((128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
)
def test_concatenate_spanning_blocks(self):
x = jax.random.normal(jax.random.key(0), (256, 256), dtype=np.float32)
y = jax.random.normal(jax.random.key(1), (256, 256), dtype=np.float32)
def f(x, y):
return jnp.concatenate([x, y], axis=0)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, x, y
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, *_: (i, j))
kernel_fn, (value_block_specs, *in_block_specs), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(4, 2),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
0
),
)(new_values, x, y)
)
self.assertEmpty(value_block_specs)
self.assertLen(in_block_specs, 2)
x_block_spec, y_block_spec = in_block_specs
self.assertEqual(x_block_spec.block_shape, (128, 128))
self.assertEqual(y_block_spec.block_shape, (128, 128))
# Indices in this concatenate should be clamped depending on whether the
# block is out of bounds (OOB).
self.assertEqual(
x_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
x_block_spec.index_map(1, 0, scalar_prefetch_values), (1, 0)
)
self.assertEqual(
x_block_spec.index_map(2, 0, scalar_prefetch_values), (1, 0)
)
self.assertEqual(
x_block_spec.index_map(3, 0, scalar_prefetch_values), (1, 0)
)
self.assertEqual(
y_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
y_block_spec.index_map(1, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
y_block_spec.index_map(2, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
y_block_spec.index_map(3, 0, scalar_prefetch_values), (1, 0)
)
x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
y = jax.random.normal(jax.random.key(1), (128, 128), dtype=np.float32)
# Evaluating should just select the valid block.
np.testing.assert_array_equal(
kernel_fn((0, 0), scalar_prefetch_values, (), x, y), x
)
np.testing.assert_array_equal(
kernel_fn((1, 0), scalar_prefetch_values, (), x, y), x
)
np.testing.assert_array_equal(
kernel_fn((2, 0), scalar_prefetch_values, (), x, y), y
)
np.testing.assert_array_equal(
kernel_fn((3, 0), scalar_prefetch_values, (), x, y), y
)
def test_concatenate_full_blocks(self):
x = jax.random.normal(jax.random.key(0), (256, 256), dtype=np.float32)
y = jax.random.normal(jax.random.key(1), (256, 256), dtype=np.float32)
def f(x, y):
return jnp.concatenate([x, y], axis=0)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, x, y
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((512, 128), lambda i, j, *_: (i, j))
kernel_fn, (value_block_specs, *in_block_specs), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 2),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
0
),
)(new_values, x, y)
)
self.assertEmpty(value_block_specs)
self.assertLen(in_block_specs, 2)
x_block_spec, y_block_spec = in_block_specs
self.assertEqual(x_block_spec.block_shape, (256, 128))
self.assertEqual(y_block_spec.block_shape, (256, 128))
# Indices in this concatenate always just fetch the whole block on that
# dimension.
self.assertEqual(
x_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
x_block_spec.index_map(0, 1, scalar_prefetch_values), (0, 1)
)
self.assertEqual(
y_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
)
self.assertEqual(
y_block_spec.index_map(0, 1, scalar_prefetch_values), (0, 1)
)
x = jax.random.normal(jax.random.key(0), (256, 128), dtype=np.float32)
y = jax.random.normal(jax.random.key(1), (256, 128), dtype=np.float32)
xy = jnp.concatenate([x, y], axis=0)
# Evaluating should concatenate the two full blocks.
np.testing.assert_array_equal(
kernel_fn((0, 0), scalar_prefetch_values, (), x, y), xy
)
np.testing.assert_array_equal(
kernel_fn((1, 0), scalar_prefetch_values, (), x, y), xy
)
def test_transpose_minor(self):
x = jax.random.normal(jax.random.key(0), (512, 256), dtype=np.float32)
def f():
return jax.lax.transpose(x, (1, 0))
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (128, 128))
self.assertEqual(
x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 0)
)
x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x.T
)
def test_transpose_major(self):
x = jax.random.normal(jax.random.key(0), (2, 3, 512, 256), dtype=np.float32)
def f():
return jax.lax.transpose(x, (1, 0, 2, 3))
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
)
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(1, 1, 1, 1),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.block_shape, (None, None, 128, 128))
self.assertEqual(
x_block_spec.index_map(0, 1, 2, 3, *scalar_prefetch_values),
(1, 0, 2, 3),
)
x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x
)
def test_iota(self):
def f():
return jax.lax.broadcasted_iota(jnp.int32, (2, 2, 512, 512), 2)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
)
kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
x = jax.lax.broadcasted_iota(jnp.int32, (128, 128), 0)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x + 128
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x + 128 * 3
)
def test_broadcast_scalar(self):
def f():
return jnp.full((1, 1, 512, 512), fill_value=1.2345, dtype=jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
)
kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x
)
def test_broadcast_scalar_with_prefetch(self):
a = jnp.array(1.2345)
def f():
return jnp.full((1, 1, 512, 512), fill_value=a, dtype=jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertEmpty(new_values)
self.assertLen(scalar_prefetch_values, 1)
block_spec = pl.BlockSpec(
(None, None, 128, 128), lambda i, j, k, l, *_: (i, j, k, l)
)
kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
scalar_prefetch_values = jax.tree.map(
lambda x: x[None], scalar_prefetch_values
)
x = jnp.full((128, 128), fill_value=a, dtype=jnp.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x
)
def test_broadcast_array(self):
x = jnp.ones((512, 512))
def f():
return jax.lax.broadcast_in_dim(x, (2, 2, 512, 512), (2, 3))
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, 1, 128, 128), lambda i, j, k, l: (i, j, k, l)
)
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
x_block_spec = value_block_specs[0]
self.assertEqual(x_block_spec.index_map(0, 0, 1, 2), (1, 2))
self.assertEqual(x_block_spec.index_map(1, 2, 3, 3), (3, 3))
x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x
)
np.testing.assert_array_equal(
kernel_fn((1, 1, 0, 0), scalar_prefetch_values, (x,)), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 1), scalar_prefetch_values, (x,)), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 1, 0), scalar_prefetch_values, (x,)), x
)
np.testing.assert_array_equal(
kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x
)
class PullBlockSpecHOPTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if config.enable_x64.value:
self.skipTest('x64 not supported')
def test_jit(self):
def f(x):
return jax.jit(jnp.sin)(x)
in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, in_type
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
)
kernel_fn, (value_block_specs, *in_block_specs), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values, in_type)
)
self.assertEmpty(value_block_specs)
x_block_spec = in_block_specs[0]
self.assertEqual(x_block_spec.index_map(0, 0, 1, 2, ()), (0, 0, 1, 2))
self.assertEqual(x_block_spec.index_map(1, 2, 3, 3, ()), (1, 2, 3, 3))
x = jax.random.normal(jax.random.key(0), (1, 128, 128), dtype=np.float32)
sin_x = jnp.sin(x)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), sin_x
)
def test_custom_jvp(self):
def f(x):
return jax.nn.relu(x)
in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
f, in_type
)
self.assertEmpty(new_values)
self.assertEmpty(scalar_prefetch_values)
block_spec = pl.BlockSpec(
(None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
)
kernel_fn, (value_block_specs, *in_block_specs), _ = (
block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values, in_type)
)
self.assertEmpty(value_block_specs)
x_block_spec = in_block_specs[0]
self.assertEqual(x_block_spec.index_map(0, 0, 1, 2, ()), (0, 0, 1, 2))
self.assertEqual(x_block_spec.index_map(1, 2, 3, 3, ()), (1, 2, 3, 3))
x = jax.random.normal(jax.random.key(0), (1, 128, 128), dtype=np.float32)
relu_x = jax.nn.relu(x)
np.testing.assert_array_equal(
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

tests/pallas/tpu_fusable_matmul_test.py (file diff suppressed because it is too large)