Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
[Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
This commit is contained in:
parent
2c7043f63d
commit
0b6c355083
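Overview: this commit adds a private `jax._src.pallas.fuser` package built from four pieces: `Fusion` (a callable bundled with its input/output types), `fusable` (marks a function whose inputs and outputs can absorb surrounding computation), `fuse` (traces a caller and splits its jaxpr around the single fusable call), and block-spec propagation utilities (`pull_block_spec`, `push_block_spec`, `get_fusion_values`). A minimal sketch of the intended flow, inferred from the files below; the `matmul` body is a stand-in (the real tests fuse into a `pl.pallas_call` matmul kernel) and all shapes are illustrative:

    import jax.numpy as jnp
    from jax._src.pallas import fuser

    # A fusable receives one Fusion per input plus an output Fusion and
    # decides where to invoke them; in a real kernel the body would be a
    # pl.pallas_call and the fusions would run in its prologue/epilogue.
    @fuser.fusable
    def matmul(x_fusion, y_fusion, out_fusion):
      z = jnp.dot(x_fusion(), y_fusion())  # materialize fused inputs
      return out_fusion(z) if out_fusion is not None else z

    @fuser.fuse
    def f(x, y):
      # exp is pulled into matmul's input fusion and the +1.0 epilogue into
      # its output fusion, instead of running as separate XLA ops.
      return matmul(jnp.exp(x), y) + 1.0

    out = f(jnp.ones((128, 128)), jnp.ones((128, 128)))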
121 jax/_src/pallas/fuser/BUILD Normal file
@@ -0,0 +1,121 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "//jaxlib:jax.bzl",
    "py_deps",
    "pytype_strict_library",
)

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//jax:internal",
    ],
)

pytype_strict_library(
    name = "fuser",
    srcs = [
        "__init__.py",
    ],
    deps = [
        ":block_spec",
        ":fusable",
        ":fusion",
        ":jaxpr_fusion",
    ],
)

pytype_strict_library(
    name = "block_spec",
    srcs = [
        "block_spec.py",
    ],
    deps = [
        "//jax",
        "//jax:ad_util",
        "//jax:api_util",
        "//jax:core",
        "//jax:partial_eval",
        "//jax:tree_util",
        "//jax:util",
        "//jax/_src/pallas",
    ] + py_deps("numpy"),
)

pytype_strict_library(
    name = "fusable",
    srcs = [
        "fusable.py",
    ],
    deps = [
        ":fusion",
        "//jax",
        "//jax:api_util",
        "//jax:core",
        "//jax:mlir",
        "//jax:partial_eval",
        "//jax:tree_util",
        "//jax:util",
    ],
)

pytype_strict_library(
    name = "fusion",
    srcs = [
        "fusion.py",
    ],
    deps = [
        "//jax",
        "//jax:util",
    ],
)

pytype_strict_library(
    name = "jaxpr_fusion",
    srcs = [
        "jaxpr_fusion.py",
    ],
    deps = [
        ":fusable",
        ":fusable_dtype",
        ":fusion",
        "//jax",
        "//jax:api_util",
        "//jax:core",
        "//jax:partial_eval",
        "//jax:tree_util",
    ],
)

pytype_strict_library(
    name = "fusable_dtype",
    srcs = [
        "fusable_dtype.py",
    ],
    deps = [
        ":block_spec",
        ":fusable",
        "//jax",
        "//jax:api_util",
        "//jax:core",
        "//jax:dtypes",
        "//jax:partial_eval",
        "//jax:source_info_util",
        "//jax:tree_util",
        "//jax:util",
        "//jax/_src/pallas",
    ],
)
21 jax/_src/pallas/fuser/__init__.py Normal file
@@ -0,0 +1,21 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
1513 jax/_src/pallas/fuser/block_spec.py Normal file
File diff suppressed because it is too large (Load Diff)
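Although the `block_spec.py` diff is suppressed, its calling convention is pinned down by `fuser_block_spec_test.py` below: `get_fusion_values` splits a function into a closure-free version plus its captured values and scalar-prefetch values, and `pull_block_spec` propagates an output `pl.BlockSpec` backwards to input block specs plus a per-block kernel function. A condensed sketch mirroring `test_elementwise` below (argument order follows the tests; nothing here is beyond what they exercise):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import pallas as pl
    from jax._src.pallas.fuser import block_spec as block_spec_lib

    def f(x):
      return jnp.exp(x)

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, values, scalar_prefetch = block_spec_lib.get_fusion_values(f, in_type)

    out_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_specs, in_spec), _ = block_spec_lib.pull_block_spec(
        f2,
        out_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(values, in_type)

    # For an elementwise op the pulled input spec matches the output spec,
    # and kernel_fn evaluates f on a single (128, 128) block.
    x_block = np.ones((128, 128), np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch, values, x_block),
        jnp.exp(x_block))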
83 jax/_src/pallas/fuser/fusable.py Normal file
@@ -0,0 +1,83 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fusable primitive."""

import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusion as fusion_lib

fusable_p = jax_core.Primitive('fusable')
fusable_p.multiple_results = True


def _get_aval(x):
  return jax_core.raise_to_shaped(jax_core.get_aval(x))


def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
  return fusion_lib.Fusion(
      func=lambda: x,
      in_type=((), {}),
      out_type=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )


def fusable(f):
  def wrapper(*args):
    def wrapped(*args):
      in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
      return f(*in_fusions, None)

    flat_args, in_tree = tree_util.tree_flatten(args)
    debug_info = api_util.debug_info('fusable', wrapped, args, {})
    flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
        lu.wrap_init(wrapped, debug_info=debug_info), in_tree
    )
    flat_avals = [_get_aval(x) for x in flat_args]
    jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
    out_tree = out_tree_thunk()
    out = fusable_p.bind(
        *consts,
        *flat_args,
        jaxpr=jaxpr,
        num_consts=len(consts),
        in_tree=in_tree,
        out_tree=out_tree,
        func=f,
    )
    return tree_util.tree_unflatten(out_tree, out)

  return wrapper


@fusable_p.def_impl
def _(*consts_and_args, jaxpr, num_consts, **_):
  consts, args = util.split_list(consts_and_args, [num_consts])
  return jax_core.eval_jaxpr(jaxpr, consts, *args)


mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl))


@fusable_p.def_abstract_eval
def _(*args, jaxpr, **kwargs):
  del args, kwargs
  return [v.aval for v in jaxpr.outvars]
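Note on `fusable`: called outside of `fuse`, the wrapper still evaluates eagerly; each argument is wrapped in a trivial `Fusion` closing over its value, the body is traced, and the registered `def_impl` simply runs the resulting jaxpr. A small sketch (the `add_one` fusable is hypothetical):

    import jax.numpy as jnp
    from jax._src.pallas import fuser

    @fuser.fusable
    def add_one(x_fusion, out_fusion):
      # out_fusion is None when the fusable is not invoked under `fuse`.
      assert out_fusion is None
      return x_fusion() + 1.0  # x_fusion is a trivial Fusion returning x

    print(add_one(jnp.zeros((4,), jnp.float32)))  # [1. 1. 1. 1.]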
465 jax/_src/pallas/fuser/fusable_dtype.py Normal file
@@ -0,0 +1,465 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom fusable dtypes."""

import abc
import dataclasses
import functools
from typing import Any, Sequence, TypeVar

import jax
from jax._src import api_util
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow import conditionals
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas.fuser import block_spec
from jax._src.pallas.fuser.fusable import fusable_p
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives

# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

T = TypeVar("T")

_physicalize_rules = {}

pack_dtype_p = core.Primitive("pack_dtype")
@pack_dtype_p.def_abstract_eval
def pack_dtype_abstract_eval(*xs, dtype):
  if dtypes.issubdtype(dtype, FusableElementDType):
    return dtype.abstract_pack(*xs)
  raise ValueError(f"Attempted to pack non-fusion dtype: {dtype}")


def pack(*xs, dtype):
  return pack_dtype_p.bind(*xs, dtype=dtype)


unpack_dtype_p = core.Primitive("unpack_dtype")
unpack_dtype_p.multiple_results = True


@unpack_dtype_p.def_abstract_eval
def unpack_dtype_abstract_eval(x):
  if dtypes.issubdtype(x.dtype, FusableElementDType):
    return x.dtype.abstract_unpack(x)
  elif isinstance(x.dtype, pallas_core.AbstractMemoryRef):
    raise NotImplementedError()
  raise ValueError(f"Attempted to unpack non-fusion dtype: {x.dtype}")


def unpack(x):
  return unpack_dtype_p.bind(x)


class FusableElementDType(dtypes.extended):
  """Scalar dtype for fusable dtypes."""

  pass


class FusableTyRules:
  allow_conversion: bool = False


class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
  """Base class for fusable extended dtypes."""

  _op_registry = {}
  _rules = FusableTyRules
  type = FusableElementDType

  @abc.abstractmethod
  def abstract_unpack(self, x) -> Sequence[Any]:
    raise NotImplementedError()

  @abc.abstractmethod
  def abstract_pack(self, *xs):
    raise NotImplementedError()

  @classmethod
  def register_op(cls, primitive):
    def _register_fn(fn):
      cls._op_registry[primitive] = fn

    return _register_fn

  @classmethod
  def get_op_rule(cls, primitive):
    return cls._op_registry.get(primitive)

  @property
  def name(self):
    return str(self)

  @abc.abstractmethod
  def pull_block_spec_one_step(self, *args, **kwargs):
    raise NotImplementedError()


def physicalize(f):
  """Runs a function that contains fusable extended dtypes."""

  def wrapper(*args, **kwargs):
    if kwargs:
      raise NotImplementedError()
    flattened_args, treedef = jax.tree.flatten(args)
    debug_info = api_util.debug_info("physicalize", f, args, kwargs)
    wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
        lu.wrap_init(f, debug_info=debug_info), treedef
    )
    avals = [core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
    new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
    out_flat = core.eval_jaxpr(
        new_jaxpr.jaxpr, new_jaxpr.consts, *flattened_args
    )
    return tree_util.tree_unflatten(out_tree_thunk(), out_flat)

  return wrapper


@util.weakref_lru_cache
def physicalize_closed_jaxpr(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
  """Replaces all extended dtypes with physical types in a jaxpr."""
  fun = functools.partial(physicalize_interp, jaxpr.jaxpr, jaxpr.consts)
  in_avals = [_physical_aval(aval) for aval in jaxpr.in_avals]
  flat_avals, treedef = tree_util.tree_flatten(in_avals)
  debug_info = api_util.debug_info("physicalize_closed_jaxpr", fun, (), {})
  wrapped_fun, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug_info), treedef
  )
  new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
  assert len(new_jaxpr.constvars) == len(consts), "Mismatched consts"
  return core.ClosedJaxpr(new_jaxpr, consts)


def _physical_aval(aval):
  if isinstance(aval, core.ShapedArray):
    return core.ShapedArray(aval.shape, aval.dtype)
  if isinstance(aval, state.AbstractRef):
    if isinstance(aval.dtype, FusionDType):
      unpacked = aval.dtype.abstract_unpack(aval.inner_aval)
      return tuple(aval.update(inner_aval=u) for u in unpacked)
    return aval
  return aval


def physicalize_jaxpr(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Replaces all extended dtypes with physical types in a jaxpr."""

  def _flat_jaxpr_eval(consts, args):
    return physicalize_interp(jaxpr, consts, *args)

  in_avals = [_physical_aval(v.aval) for v in jaxpr.invars]
  const_avals = [_physical_aval(v.aval) for v in jaxpr.constvars]
  flat_avals, treedef = jax.tree.flatten((const_avals, in_avals))
  debug_info = api_util.debug_info(
      "physicalize_jaxpr", _flat_jaxpr_eval, (), {}
  )
  wrapped_fun, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(_flat_jaxpr_eval, debug_info=debug_info), treedef
  )
  new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
  assert not consts
  new_jaxpr = pe.convert_invars_to_constvars(
      new_jaxpr, len(tree_util.tree_leaves(const_avals))
  )
  return new_jaxpr


@dataclasses.dataclass
class Context:
  avals_in: Sequence[Any]
  avals_out: Sequence[Any]


def physicalize_interp(
    jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
):
  """Physicalizes a jaxpr by replacing fusable dtypes with physical types."""
  # TODO: Merge into JAX core.
  env: dict[core.Var, Any] = {}

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  map(write_env, jaxpr.constvars, consts)
  assert len(jaxpr.invars) == len(
      args
  ), f"Length mismatch: {jaxpr.invars} != {args}"
  map(write_env, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    invals = list(map(read_env, eqn.invars))
    avals_in = tuple(x.aval for x in eqn.invars)
    name_stack = (
        source_info_util.current_name_stack() + eqn.source_info.name_stack
    )
    with (
        source_info_util.user_context(
            eqn.source_info.traceback, name_stack=name_stack
        ),
        eqn.ctx.manager,
    ):
      # Need to check types and then invoke the correct rule.
      in_types = [aval.dtype for aval in avals_in]  # pytype: disable=attribute-error
      ctx = Context(
          avals_in=avals_in, avals_out=[var.aval for var in eqn.outvars]
      )
      custom_rule = _phys_find_rule(eqn.primitive, in_types)
      if custom_rule:
        outvals = custom_rule(ctx, *invals, **eqn.params)
      else:
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)

    if eqn.primitive.multiple_results:
      assert len(outvals) == len(eqn.outvars)
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)

  return map(read_env, jaxpr.outvars)


def _phys_find_rule(primitive, types: Sequence[dtypes.DType]):
  """Finds the physicalization rule for a primitive."""
  if primitive in _physicalize_rules:
    return _physicalize_rules[primitive]
  fusion_types = {type_ for type_ in types if isinstance(type_, FusionDType)}
  if len(fusion_types) == 0:
    return None
  elif len(fusion_types) > 1:
    raise ValueError(f"Multiple fusion types for primitive: {fusion_types}")
  fusion_type = fusion_types.pop()
  if primitive not in fusion_type._op_registry:
    raise ValueError(
        f"No implementation found for primitive {primitive} "
        f"for custom type {fusion_type}"
    )
  return fusion_type.get_op_rule(primitive)


def _assert_no_fusion_types(avals: Sequence[core.AbstractValue]):
  for aval in avals:
    if isinstance(aval, (core.ShapedArray, state.AbstractRef)):
      if isinstance(aval.dtype, FusionDType):
        raise NotImplementedError(f"Fusion type found in avals: {avals}")


def _pallas_call_physicalize_rule(
    ctx: Context, *args, jaxpr, grid_mapping: pallas_core.GridMapping, **kwargs
):
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  with grid_mapping.trace_env():
    new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, ()))
  num_new_vals = len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars)
  grid_mapping = grid_mapping.replace(
      num_scratch_operands=grid_mapping.num_scratch_operands + num_new_vals
  )
  return pallas_call.pallas_call_p.bind(
      *args, jaxpr=new_jaxpr.jaxpr, grid_mapping=grid_mapping, **kwargs
  )


_physicalize_rules[pallas_call.pallas_call_p] = _pallas_call_physicalize_rule


def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs):
  _assert_no_fusion_types(ctx.avals_out)
  physicalized_branches = [
      physicalize_closed_jaxpr(branch) for branch in branches
  ]
  flat_args = jax.tree.leaves(args)
  return conditionals.cond_p.bind(
      *flat_args, branches=physicalized_branches, **kwargs
  )


_physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule


def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized):
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  jaxpr = physicalize_jaxpr(jaxpr)
  return state_discharge.run_state_p.bind(
      *args,
      jaxpr=jaxpr,
      which_linear=which_linear,
      is_initialized=is_initialized,
  )


_physicalize_rules[state_discharge.run_state_p] = _run_state_rule


def _core_map_rule(ctx: Context, *args, jaxpr, **params):
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  assert not jaxpr.invars
  with core.extend_axis_env_nd(params["mesh"].shape.items()):
    jaxpr = physicalize_jaxpr(jaxpr)
  return pallas_core.core_map_p.bind(*args, jaxpr=jaxpr, **params)


_physicalize_rules[pallas_core.core_map_p] = _core_map_rule


def _run_scoped_rule(ctx: Context, *args, jaxpr, **params):
  _assert_no_fusion_types(ctx.avals_out)
  jaxpr = physicalize_jaxpr(jaxpr)
  flat_args = tree_util.tree_leaves(args)
  assert len(flat_args) == len(
      jaxpr.constvars
  ), f"Length mismatch: {len(flat_args)=} != {len(jaxpr.constvars)=}"
  return pallas_primitives.run_scoped_p.bind(*flat_args, jaxpr=jaxpr, **params)


_physicalize_rules[pallas_primitives.run_scoped_p] = _run_scoped_rule


def _scan_rule(ctx: Context, *args, jaxpr, **params):
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  jaxpr = physicalize_closed_jaxpr(jaxpr)
  return jax.lax.scan_p.bind(*args, jaxpr=jaxpr, **params)


_physicalize_rules[jax.lax.scan_p] = _scan_rule


def _while_rule(
    ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts, **params
):
  _assert_no_fusion_types(ctx.avals_out)
  cond_avals = [v.aval for v in cond_jaxpr.jaxpr.invars]
  _assert_no_fusion_types(cond_avals)
  body_avals = [v.aval for v in body_jaxpr.jaxpr.invars]
  _, body_in_avals = util.split_list(body_avals, [body_nconsts])
  _assert_no_fusion_types(body_in_avals)
  new_body_jaxpr = physicalize_closed_jaxpr(body_jaxpr)
  new_num_body_consts = (
      body_nconsts
      + len(new_body_jaxpr.jaxpr.invars)
      - len(body_jaxpr.jaxpr.invars)
  )
  flat_args = tree_util.tree_leaves(args)
  assert len(flat_args) == len(new_body_jaxpr.jaxpr.invars), (
      f"Length mismatch: {len(flat_args)=} !="
      f" {len(new_body_jaxpr.jaxpr.invars)=}"
  )
  return jax.lax.while_p.bind(
      *flat_args,
      body_jaxpr=new_body_jaxpr,
      cond_jaxpr=cond_jaxpr,
      body_nconsts=new_num_body_consts,
      **params,
  )


_physicalize_rules[jax.lax.while_p] = _while_rule


def _pack_rule(_, *args, dtype):
  del dtype
  return args


_physicalize_rules[pack_dtype_p] = _pack_rule


def _unpack_rule(_, arg):
  return arg


_physicalize_rules[unpack_dtype_p] = _unpack_rule


def _swap_rule(ctx: Context, ref, val, *args, tree):
  ref_aval, *_ = ctx.avals_in
  if not isinstance(ref_aval.dtype, FusionDType):
    return state_primitives.swap_p.bind(ref, val, *args, tree=tree)
  return ref_aval.dtype.swap(ref, val, *args, tree=tree)


_physicalize_rules[state_primitives.swap_p] = _swap_rule


def _get_rule(ctx: Context, ref, *args, tree):
  ref_aval, *_ = ctx.avals_in
  if not isinstance(ref_aval.dtype, FusionDType):
    return state_primitives.get_p.bind(ref, *args, tree=tree)
  return ref_aval.dtype.get(ref, *args, tree=tree)


_physicalize_rules[state_primitives.get_p] = _get_rule


@block_spec.register_eval_rule(pack_dtype_p)
def _pack_dtype_eval_rule(eval_ctx: block_spec.KernelEvalContext, *args, dtype):
  del eval_ctx
  return pack_dtype_p.bind(*args, dtype=dtype)


@block_spec.register_pull_block_spec_rule(pack_dtype_p)
def _pack_dtype_pull_rule(
    ctx: block_spec.PullRuleContext,
    block_spec: pallas_core.BlockSpec,
    *,
    dtype: FusionDType,
):
  del ctx
  return dtype.pull_block_spec_one_step(block_spec)  # pytype: disable=attribute-error


def _fusable_physicalize_rule(
    _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
):
  consts, _ = util.split_list(consts_and_args, [num_consts])
  new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
  return fusable_p.bind(
      *consts_and_args,
      jaxpr=new_jaxpr.jaxpr,
      num_consts=num_consts,
      in_tree=in_tree,
      out_tree=out_tree,
      func=func,
  )


_physicalize_rules[fusable_p] = _fusable_physicalize_rule
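To plug into this machinery, a concrete `FusionDType` implements the three abstract methods above and registers per-primitive rules via `register_op`. A skeletal, untested sketch of the shape of such a subclass (`PairDType` is hypothetical, standing for a value physically stored as two float32 arrays; nothing like it ships in this commit):

    import dataclasses
    import numpy as np
    from jax._src import core
    from jax._src.pallas.fuser import fusable_dtype

    @dataclasses.dataclass(frozen=True)
    class PairDType(fusable_dtype.FusionDType):

      def abstract_pack(self, x, y):
        # Two physical avals collapse into one logical aval of this dtype.
        assert x.shape == y.shape
        return core.ShapedArray(x.shape, self)

      def abstract_unpack(self, x):
        # One logical aval expands into its two physical float32 avals.
        return [core.ShapedArray(x.shape, np.dtype('float32'))] * 2

      def pull_block_spec_one_step(self, block_spec):
        # Both physical operands reuse the logical output's block spec.
        return [block_spec, block_spec]

      def __repr__(self):
        return 'PairDType'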
59 jax/_src/pallas/fuser/fusion.py Normal file
@@ -0,0 +1,59 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fusion classes."""

from __future__ import annotations

import dataclasses
from typing import Any, Callable, Generic, ParamSpec, TypeVar

import jax
from jax._src import util

safe_map = util.safe_map

A = ParamSpec("A")
K = TypeVar("K")


@dataclasses.dataclass
class Fusion(Generic[A, K]):

  func: Callable[A, K]
  in_type: tuple[tuple[Any, ...], dict[str, Any]]
  out_type: Any

  def __call__(self, *args: A.args, **kwargs: A.kwargs):
    return self.func(*args, **kwargs)

  @property
  def shape(self):
    return jax.tree.map(lambda x: x.shape, self.out_type)

  @property
  def dtype(self):
    return jax.tree.map(lambda x: x.dtype, self.out_type)

  @property
  def type(self):
    return self.out_type

  @property
  def in_shape(self):
    return jax.tree.map(lambda x: x.shape, self.in_type)

  @property
  def in_dtype(self):
    return jax.tree.map(lambda x: x.dtype, self.in_type)
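`Fusion` is deliberately thin: a callable plus pytrees of `jax.ShapeDtypeStruct`s describing its inputs and outputs, so shape and dtype metadata can be inspected without invoking it. For instance (shapes illustrative):

    import jax
    import jax.numpy as jnp
    from jax._src.pallas.fuser import fusion as fusion_lib

    fusion = fusion_lib.Fusion(
        func=jnp.exp,
        in_type=((jax.ShapeDtypeStruct((128, 128), jnp.float32),), {}),
        out_type=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    )
    print(fusion.shape)  # (128, 128), tree-mapped over out_type
    print(fusion.dtype)  # float32
    y = fusion(jnp.zeros((128, 128), jnp.float32))  # just calls func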
147 jax/_src/pallas/fuser/jaxpr_fusion.py Normal file
@@ -0,0 +1,147 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fuses a function."""

from typing import Any

import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe

from jax._src.pallas.fuser import fusable_dtype
from jax._src.pallas.fuser import fusion as fusion_lib
from jax._src.pallas.fuser.fusable import fusable_p


def _get_aval(x):
  return jax_core.raise_to_shaped(jax_core.get_aval(x))


def fuse(f, *, physicalize: bool = False):
  """Fuses a function into a single fusable.

  There should be a single call to a `fusable` inside the body of `f`. `fuse`
  returns a transformed function that will fuse the surrounding computation
  into the fusable and invoke it.
  """

  def wrapper(*args, **kwargs):
    flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
    debug_info = api_util.debug_info('fuse', f, args, kwargs)
    flat_fun, out_tree_thunk = api_util.flatten_fun(
        lu.wrap_init(f, debug_info=debug_info), in_tree
    )
    flat_avals = [_get_aval(x) for x in flat_args]
    jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
    out_tree = out_tree_thunk()
    out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
    return tree_util.tree_unflatten(out_tree, out_flat)

  if physicalize:
    wrapper = fusable_dtype.physicalize(wrapper)
  return wrapper


_fusable: dict[jax_core.Primitive, Any] = {}


def construct_fusion(
    candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
) -> fusion_lib.Fusion:
  flat_outvars, out_tree = tree_util.tree_flatten(outvars)
  flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs))
  new_jaxpr_no_dce = jaxpr.replace(
      outvars=flat_outvars,
      constvars=jaxpr.constvars + jaxpr.invars,
      invars=flat_invars,
  )
  new_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
      new_jaxpr_no_dce,
      [True] * len(new_jaxpr_no_dce.outvars),
      instantiate=[False] * len(new_jaxpr_no_dce.constvars)
      + [True] * len(new_jaxpr_no_dce.invars),
  )
  assert all(used_invars), new_jaxpr_no_dce
  new_values = tuple(
      c for used, c in zip(used_consts, candidate_values, strict=True) if used
  )
  kernel_in_tree = tree_util.tree_structure((invars, kwargs))

  def _fn(*args, **kwargs):
    flat_args, _ = tree_util.tree_flatten((args, kwargs))
    out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args)
    return tree_util.tree_unflatten(out_tree, out_flat)

  flat_in_type = [
      jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars
  ]
  in_type = tree_util.tree_unflatten(kernel_in_tree, flat_in_type)
  out_type = tree_util.tree_unflatten(
      out_tree,
      [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars],
  )
  return fusion_lib.Fusion(_fn, in_type, out_type)


def fuse_jaxpr(
    jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args
):
  fusion_eqn_index = None

  # Collect input fusions
  for i, eqn in enumerate(jaxpr.eqns):
    if eqn.primitive is fusable_p:
      fusion_eqn_index = i
      break
  if fusion_eqn_index is None:
    raise ValueError("No fusable eqn found")
  fusion_eqn = jaxpr.eqns[fusion_eqn_index]

  candidate_values = [*consts, *args]

  # Construct fusions for non-constant inputs to the fusable.
  in_fusions_flat = [
      construct_fusion(
          candidate_values,
          jaxpr.replace(
              eqns=jaxpr.eqns[:fusion_eqn_index],
          ),
          var,
      )
      for var in fusion_eqn.invars[fusion_eqn.params["num_consts"] :]
  ]
  in_fusions = tree_util.tree_unflatten(
      fusion_eqn.params["in_tree"], in_fusions_flat
  )
  out_fusion = construct_fusion(
      candidate_values,
      jaxpr.replace(
          eqns=jaxpr.eqns[:fusion_eqn_index]
          + jaxpr.eqns[fusion_eqn_index + 1 :]
      ),
      tree_util.tree_unflatten(out_tree, jaxpr.outvars),
      tree_util.tree_unflatten(
          fusion_eqn.params["out_tree"], fusion_eqn.outvars
      ),
  )
  # Run the fusable.
  out = fusion_eqn.params["func"](*in_fusions, out_fusion)

  # Now return the flattened output (the fuse_jaxpr caller should unflatten).
  out_flat = tree_util.tree_leaves(out)
  assert len(out_flat) == len(jaxpr.outvars)
  return out_flat
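The split performed by `fuse_jaxpr` is worth spelling out: equations before the `fusable` equation become the input fusions, equations after it become the output fusion, and `dce_jaxpr_consts` prunes each slice down to what its outputs actually need. A small end-to-end check, hedged the same way as the sketch above (`double` is hypothetical and kernel-free):

    import jax.numpy as jnp
    import numpy as np
    from jax._src.pallas import fuser

    @fuser.fusable
    def double(x_fusion, out_fusion):
      # Under `fuse`, out_fusion carries everything downstream of this
      # call (the tanh below); called standalone it would be None.
      y = 2.0 * x_fusion()
      return out_fusion(y) if out_fusion is not None else y

    def f(x):
      return jnp.tanh(double(jnp.exp(x)))  # exp -> in_fusion, tanh -> out_fusion

    x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
    np.testing.assert_allclose(
        fuser.fuse(f)(x), jnp.tanh(2.0 * jnp.exp(x)), rtol=1e-6)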
tests/pallas/BUILD
@@ -589,3 +589,60 @@ jax_multiplatform_test(
        "//jax:pallas_mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_multiplatform_test(
    name = "fuser_block_spec_test",
    srcs = [
        "fuser_block_spec_test.py",
    ],
    disable_configs = [
        "cpu",
        "cpu_shardy",
    ],
    enable_backends = ["cpu"],
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],
    deps = [
        "//jax:pallas",
        "//jax/_src/pallas/fuser",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_multiplatform_test(
    name = "tpu_fusable_matmul_test",
    srcs = ["tpu_fusable_matmul_test.py"],
    disable_configs = [
        "tpu_v3_1x1",
        "tpu_pjrt_c_api",
        "gpu_v100",
        "gpu_v100_x32",
        "gpu_a100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_h100",
        "cpu",
        "cpu_x32",
        "cpu_shardy",
    ],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v4_1x1",
        "tpu_v5e",
        "tpu_v5p_1x1",
        "tpu_v6e_1x1",
    ],
    shard_count = 4,
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],
    deps = [
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
        "//jax/_src/pallas/fuser",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
776 tests/pallas/fuser_block_spec_test.py Normal file
@@ -0,0 +1,776 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for pull block spec."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.fuser import block_spec as block_spec_lib
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

jax.config.parse_flags_with_absl()


class PullBlockSpecTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if config.enable_x64.value:
      self.skipTest('x64 not supported')

  def test_identity(self):

    def f(x):
      return x

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, in_type
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs, in_block_spec), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(1, 1, 1),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
        )(new_values, in_type)
    )
    # We should have no new values or scalar prefetch values.
    self.assertEmpty(scalar_prefetch_values)
    self.assertEmpty(value_block_specs)
    self.assertEqual(in_block_spec.block_shape, (128, 128))
    self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))

    x = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, new_values, x),
        x,
    )

  def test_const(self):

    x = np.ones((512, 512), dtype=np.float32)

    def f():
      return x

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    self.assertEmpty(scalar_prefetch_values)
    self.assertEqual(value_block_specs[0].block_shape, (128, 128))
    self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (0, 1))

    x_block = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn(
            (0, 0, 0),
            scalar_prefetch_values,
            (np.ones((128, 128), dtype=np.float32),),
        ),
        x_block,
    )

  @parameterized.parameters([jnp.exp, jnp.tanh])
  def test_elementwise(self, fn):

    def f(x):
      return fn(x)

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, in_type
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs, in_block_spec), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(1, 1, 1),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
                0
            ),
        )(new_values, in_type)
    )
    self.assertEmpty(value_block_specs)
    self.assertEqual(in_block_spec.block_shape, (128, 128))
    self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))

    x = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (), x),
        fn(x),
    )

  @parameterized.parameters([jnp.exp, jnp.tanh])
  def test_elementwise_bias(self, fn):

    b = np.ones((512, 512), dtype=np.float32)

    def f(x):
      return fn(x) + b

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, in_type
    )
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs, in_block_spec), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(1, 1, 1),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
        )(new_values, in_type)
    )
    self.assertLen(value_block_specs, 1)
    self.assertEqual(value_block_specs[0].block_shape, (128, 128))
    self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (0, 1))
    self.assertEqual(in_block_spec.block_shape, (128, 128))
    self.assertEqual(in_block_spec.index_map(0, 1, 2), (0, 1))

    x = np.ones((128, 128), dtype=np.float32)
    b = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (b,), x),
        fn(x) + b,
    )

  @parameterized.product(
      fn=[lax.mul, lax.add, lax.sub, lax.div, lax.max, lax.lt, lax.eq, lax.gt],
  )
  def test_binop(self, fn):

    def f(x, y):
      return fn(x, y)

    in_type = (
        jax.ShapeDtypeStruct((512, 512), jnp.float32),
        jax.ShapeDtypeStruct((512, 512), jnp.float32),
    )
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, *in_type
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs, *in_block_specs), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(1, 1, 1),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
        )(new_values, *in_type)
    )
    self.assertEmpty(value_block_specs)
    self.assertLen(in_block_specs, 2)
    x_block_spec, y_block_spec = in_block_specs
    self.assertEqual(x_block_spec.block_shape, (128, 128))
    self.assertEqual(
        x_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2)
    )
    self.assertEqual(y_block_spec.block_shape, (128, 128))
    self.assertEqual(
        y_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2)
    )

    x = np.ones((128, 128), dtype=np.float32)
    y = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, new_values, x, y),
        fn(x, y),
    )

  def test_slice(self):
    x = jax.random.normal(jax.random.key(0), (4, 512, 512), dtype=np.float32)

    def f():
      return jax.lax.slice(x, (1, 0, 0), (2, 512, 512))

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((1, 128, 128), lambda i, j, k: (0, i, j))
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(0),
    )(new_values)

    self.assertLen(value_block_specs, 1)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (1, 128, 128))
    self.assertEqual(x_block_spec.index_map(4, 2, 3), (1, 4, 2))

    x = np.ones((1, 128, 128), dtype=np.float32)
    # Slice doesn't change value after pulling block spec.
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
    )

  def test_squeeze(self):

    x = jax.random.normal(jax.random.key(0), (1, 512, 512), dtype=np.float32)

    def f():
      return jnp.squeeze(x, axis=0)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (None, 128, 128))
    self.assertEqual(x_block_spec.index_map(4, 2, 3), (0, 4, 2))

    x = np.ones((128, 128), dtype=np.float32)
    # Squeeze doesn't change value after pulling block spec and the squeezed
    # dimensions are removed.
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
    )

  def test_dynamic_slice_only(self):

    x = jax.random.normal(jax.random.key(0), (3, 4, 512, 512), dtype=np.float32)
    i = jnp.array(1, dtype=jnp.int32)
    j = jnp.array(2, dtype=jnp.int32)

    def f():
      return jax.lax.dynamic_slice(x, (i, j, 0, 0), (1, 1, 512, 512))

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertLen(scalar_prefetch_values, 2)

    block_spec = pl.BlockSpec(
        (1, 1, 128, 128), lambda i, j, k, *_: (0, 0, i, j)
    )
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    scalar_prefetch_values = jax.tree.map(
        lambda x: x[None], scalar_prefetch_values
    )
    self.assertLen(value_block_specs, 1)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (1, 1, 128, 128))
    self.assertEqual(
        x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 2, 0, 1)
    )

    x = np.ones((1, 1, 128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
    )

  def test_dynamic_slice_squeeze(self):
    x = jax.random.normal(jax.random.key(0), (3, 4, 512, 512), dtype=np.float32)
    i = jnp.array(1, dtype=jnp.int32)
    j = jnp.array(2, dtype=jnp.int32)

    def f():
      return jnp.squeeze(
          jax.lax.dynamic_slice(x, (i, j, 0, 0), (1, 1, 512, 512)), axis=(0, 1)
      )

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertLen(scalar_prefetch_values, 2)
    scalar_prefetch_values = jax.tree.map(
        lambda x: x[None], scalar_prefetch_values
    )

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k, *_: (i, j))
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (None, None, 128, 128))
    self.assertEqual(
        x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 2, 0, 1)
    )

    x = np.ones((128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x
    )

  def test_concatenate_spanning_blocks(self):
    x = jax.random.normal(jax.random.key(0), (256, 256), dtype=np.float32)
    y = jax.random.normal(jax.random.key(1), (256, 256), dtype=np.float32)

    def f(x, y):
      return jnp.concatenate([x, y], axis=0)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, x, y
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, *_: (i, j))
    kernel_fn, (value_block_specs, *in_block_specs), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(4, 2),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
                0
            ),
        )(new_values, x, y)
    )
    self.assertEmpty(value_block_specs)
    self.assertLen(in_block_specs, 2)
    x_block_spec, y_block_spec = in_block_specs
    self.assertEqual(x_block_spec.block_shape, (128, 128))
    self.assertEqual(y_block_spec.block_shape, (128, 128))
    # Indices in this concatenate should be clamped depending on whether the
    # block is OOB.
    self.assertEqual(
        x_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        x_block_spec.index_map(1, 0, scalar_prefetch_values), (1, 0)
    )
    self.assertEqual(
        x_block_spec.index_map(2, 0, scalar_prefetch_values), (1, 0)
    )
    self.assertEqual(
        x_block_spec.index_map(3, 0, scalar_prefetch_values), (1, 0)
    )
    self.assertEqual(
        y_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        y_block_spec.index_map(1, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        y_block_spec.index_map(2, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        y_block_spec.index_map(3, 0, scalar_prefetch_values), (1, 0)
    )

    x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
    y = jax.random.normal(jax.random.key(1), (128, 128), dtype=np.float32)
    # Evaluating should just select the valid block.
    np.testing.assert_array_equal(
        kernel_fn((0, 0), scalar_prefetch_values, (), x, y), x
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 0), scalar_prefetch_values, (), x, y), x
    )
    np.testing.assert_array_equal(
        kernel_fn((2, 0), scalar_prefetch_values, (), x, y), y
    )
    np.testing.assert_array_equal(
        kernel_fn((3, 0), scalar_prefetch_values, (), x, y), y
    )

  def test_concatenate_full_blocks(self):
    x = jax.random.normal(jax.random.key(0), (256, 256), dtype=np.float32)
    y = jax.random.normal(jax.random.key(1), (256, 256), dtype=np.float32)

    def f(x, y):
      return jnp.concatenate([x, y], axis=0)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, x, y
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((512, 128), lambda i, j, *_: (i, j))
    kernel_fn, (value_block_specs, *in_block_specs), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(1, 2),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(
                0
            ),
        )(new_values, x, y)
    )
    self.assertEmpty(value_block_specs)
    self.assertLen(in_block_specs, 2)
    x_block_spec, y_block_spec = in_block_specs
    self.assertEqual(x_block_spec.block_shape, (256, 128))
    self.assertEqual(y_block_spec.block_shape, (256, 128))
    # Indices in this concatenate always just fetch the whole block on that
    # dimension.
    self.assertEqual(
        x_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        x_block_spec.index_map(0, 1, scalar_prefetch_values), (0, 1)
    )
    self.assertEqual(
        y_block_spec.index_map(0, 0, scalar_prefetch_values), (0, 0)
    )
    self.assertEqual(
        y_block_spec.index_map(0, 1, scalar_prefetch_values), (0, 1)
    )

    x = jax.random.normal(jax.random.key(0), (256, 128), dtype=np.float32)
    y = jax.random.normal(jax.random.key(1), (256, 128), dtype=np.float32)
    xy = jnp.concatenate([x, y], axis=0)
    # Evaluating should concatenate the two full blocks.
    np.testing.assert_array_equal(
        kernel_fn((0, 0), scalar_prefetch_values, (), x, y), xy
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 0), scalar_prefetch_values, (), x, y), xy
    )

  def test_transpose_minor(self):
    x = jax.random.normal(jax.random.key(0), (512, 256), dtype=np.float32)

    def f():
      return jax.lax.transpose(x, (1, 0))

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j))
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (128, 128))
    self.assertEqual(
        x_block_spec.index_map(0, 1, 2, *scalar_prefetch_values), (1, 0)
    )

    x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0), scalar_prefetch_values, (x,)), x.T
    )

  def test_transpose_major(self):
    x = jax.random.normal(jax.random.key(0), (2, 3, 512, 256), dtype=np.float32)

    def f():
      return jax.lax.transpose(x, (1, 0, 2, 3))

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
    )
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(1, 1, 1, 1),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.block_shape, (None, None, 128, 128))
    self.assertEqual(
        x_block_spec.index_map(0, 1, 2, 3, *scalar_prefetch_values),
        (1, 0, 2, 3),
    )

    x = jax.random.normal(jax.random.key(0), (128, 128), dtype=np.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x
    )

  def test_iota(self):

    def f():
      return jax.lax.broadcasted_iota(jnp.int32, (2, 2, 512, 512), 2)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
    )
    kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(2, 2, 4, 4),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)

    x = jax.lax.broadcasted_iota(jnp.int32, (128, 128), 0)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x + 128
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x + 128 * 3
    )

  def test_broadcast_scalar(self):

    def f():
      return jnp.full((1, 1, 512, 512), fill_value=1.2345, dtype=jnp.float32)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, None, 128, 128), lambda i, j, k, l: (i, j, k, l)
    )
    kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(2, 2, 4, 4),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)

    x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x
    )

  def test_broadcast_scalar_with_prefetch(self):

    a = jnp.array(1.2345)

    def f():
      return jnp.full((1, 1, 512, 512), fill_value=a, dtype=jnp.float32)

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertEmpty(new_values)
    self.assertLen(scalar_prefetch_values, 1)

    block_spec = pl.BlockSpec(
        (None, None, 128, 128), lambda i, j, k, l, *_: (i, j, k, l)
    )
    kernel_fn, ((),), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(2, 2, 4, 4),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    scalar_prefetch_values = jax.tree.map(
        lambda x: x[None], scalar_prefetch_values
    )

    x = jnp.full((128, 128), fill_value=a, dtype=jnp.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 1, 0, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 1), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 1, 0), scalar_prefetch_values, ()), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x
    )

  def test_broadcast_array(self):

    x = jnp.ones((512, 512))

    def f():
      return jax.lax.broadcast_in_dim(x, (2, 2, 512, 512), (2, 3))

    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
    self.assertLen(new_values, 1)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, 1, 128, 128), lambda i, j, k, l: (i, j, k, l)
    )
    kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
        f2,
        block_spec,
        grid=(2, 2, 4, 4),
        scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
    )(new_values)
    self.assertLen(value_block_specs, 1)
    x_block_spec = value_block_specs[0]
    self.assertEqual(x_block_spec.index_map(0, 0, 1, 2), (1, 2))
    self.assertEqual(x_block_spec.index_map(1, 2, 3, 3), (3, 3))

    x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x
    )
    np.testing.assert_array_equal(
        kernel_fn((1, 1, 0, 0), scalar_prefetch_values, (x,)), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 1), scalar_prefetch_values, (x,)), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 1, 0), scalar_prefetch_values, (x,)), x
    )
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x
    )


class PullBlockSpecHOPTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if config.enable_x64.value:
      self.skipTest('x64 not supported')

  def test_jit(self):
    def f(x):
      return jax.jit(jnp.sin)(x)

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, in_type
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
    )
    kernel_fn, (value_block_specs, *in_block_specs), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(2, 2, 4, 4),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
        )(new_values, in_type)
    )
    self.assertEmpty(value_block_specs)
    x_block_spec = in_block_specs[0]
    self.assertEqual(x_block_spec.index_map(0, 0, 1, 2, ()), (0, 0, 1, 2))
    self.assertEqual(x_block_spec.index_map(1, 2, 3, 3, ()), (1, 2, 3, 3))

    x = jax.random.normal(jax.random.key(0), (1, 128, 128), dtype=np.float32)
    sin_x = jnp.sin(x)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), sin_x
    )

  def test_custom_jvp(self):
    def f(x):
      return jax.nn.relu(x)

    in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32)
    f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(
        f, in_type
    )
    self.assertEmpty(new_values)
    self.assertEmpty(scalar_prefetch_values)

    block_spec = pl.BlockSpec(
        (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
    )
    kernel_fn, (value_block_specs, *in_block_specs), _ = (
        block_spec_lib.pull_block_spec(
            f2,
            block_spec,
            grid=(2, 2, 4, 4),
            scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
        )(new_values, in_type)
    )
    self.assertEmpty(value_block_specs)
    x_block_spec = in_block_specs[0]
    self.assertEqual(x_block_spec.index_map(0, 0, 1, 2, ()), (0, 0, 1, 2))
    self.assertEqual(x_block_spec.index_map(1, 2, 3, 3, ()), (1, 2, 3, 3))

    x = jax.random.normal(jax.random.key(0), (1, 128, 128), dtype=np.float32)
    relu_x = jax.nn.relu(x)
    np.testing.assert_array_equal(
        kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x
    )


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
1039 tests/pallas/tpu_fusable_matmul_test.py Normal file
File diff suppressed because it is too large (Load Diff)