mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Fix spelling of 'fusible'.
PiperOrigin-RevId: 747663692
This commit is contained in:
parent
0ed0fb7c54
commit
1926b99bfd
@ -721,7 +721,7 @@ pytype_strict_library(
|
||||
":pallas", # build_cleaner: keep
|
||||
"//jax/_src/pallas/fuser:block_spec",
|
||||
"//jax/_src/pallas/fuser:custom_evaluate",
|
||||
"//jax/_src/pallas/fuser:fusable",
|
||||
"//jax/_src/pallas/fuser:fusible",
|
||||
"//jax/_src/pallas/fuser:fusion",
|
||||
"//jax/_src/pallas/fuser:jaxpr_fusion",
|
||||
],
|
||||
|
@ -33,7 +33,7 @@ pytype_strict_library(
|
||||
deps = [
|
||||
":block_spec",
|
||||
":custom_evaluate",
|
||||
":fusable",
|
||||
":fusible",
|
||||
":fusion",
|
||||
":jaxpr_fusion",
|
||||
],
|
||||
@ -58,9 +58,9 @@ pytype_strict_library(
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "fusable",
|
||||
name = "fusible",
|
||||
srcs = [
|
||||
"fusable.py",
|
||||
"fusible.py",
|
||||
],
|
||||
deps = [
|
||||
":fusion",
|
||||
@ -91,8 +91,8 @@ pytype_strict_library(
|
||||
"jaxpr_fusion.py",
|
||||
],
|
||||
deps = [
|
||||
":fusable",
|
||||
":fusable_dtype",
|
||||
":fusible",
|
||||
":fusible_dtype",
|
||||
":fusion",
|
||||
"//jax",
|
||||
"//jax:api_util",
|
||||
@ -104,13 +104,13 @@ pytype_strict_library(
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "fusable_dtype",
|
||||
name = "fusible_dtype",
|
||||
srcs = [
|
||||
"fusable_dtype.py",
|
||||
"fusible_dtype.py",
|
||||
],
|
||||
deps = [
|
||||
":block_spec",
|
||||
":fusable",
|
||||
":fusible",
|
||||
"//jax",
|
||||
"//jax:api_util",
|
||||
"//jax:core",
|
||||
|
@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
|
||||
from jax._src.pallas.fuser.fusable import fusable as fusable
|
||||
from jax._src.pallas.fuser.fusible import fusible as fusible
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fusable primitive."""
|
||||
"""Fusible primitive."""
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
@ -25,8 +25,8 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas.fuser import fusion as fusion_lib
|
||||
|
||||
fusable_p = jax_core.Primitive('fusable')
|
||||
fusable_p.multiple_results = True
|
||||
fusible_p = jax_core.Primitive('fusible')
|
||||
fusible_p.multiple_results = True
|
||||
|
||||
|
||||
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
|
||||
@ -37,7 +37,7 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
|
||||
)
|
||||
|
||||
|
||||
def fusable(f=None, *, output_fusion_prefix: Any = True):
|
||||
def fusible(f=None, *, output_fusion_prefix: Any = True):
|
||||
def decorator(f):
|
||||
def wrapper(*args):
|
||||
def wrapped(*args):
|
||||
@ -45,14 +45,14 @@ def fusable(f=None, *, output_fusion_prefix: Any = True):
|
||||
return f(*in_fusions, None)
|
||||
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
debug_info = api_util.debug_info('fusable', wrapped, args, {})
|
||||
debug_info = api_util.debug_info('fusible', wrapped, args, {})
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
|
||||
)
|
||||
flat_avals = [jax_core.get_aval(x) for x in flat_args]
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
out = fusable_p.bind(
|
||||
out = fusible_p.bind(
|
||||
*consts,
|
||||
*flat_args,
|
||||
jaxpr=jaxpr,
|
||||
@ -71,16 +71,16 @@ def fusable(f=None, *, output_fusion_prefix: Any = True):
|
||||
return decorator
|
||||
|
||||
|
||||
@fusable_p.def_impl
|
||||
@fusible_p.def_impl
|
||||
def _(*consts_and_args, jaxpr, num_consts, **_):
|
||||
consts, args = util.split_list(consts_and_args, [num_consts])
|
||||
return jax_core.eval_jaxpr(jaxpr, consts, *args)
|
||||
|
||||
|
||||
mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl))
|
||||
mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl))
|
||||
|
||||
|
||||
@fusable_p.def_abstract_eval
|
||||
@fusible_p.def_abstract_eval
|
||||
def _(*args, jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
return [v.aval for v in jaxpr.outvars]
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom fusable dtypes."""
|
||||
"""Custom fusible dtypes."""
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
@ -34,7 +34,7 @@ from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives as pallas_primitives
|
||||
from jax._src.pallas.fuser import block_spec
|
||||
from jax._src.pallas.fuser.fusable import fusable_p
|
||||
from jax._src.pallas.fuser.fusible import fusible_p
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.util import foreach
|
||||
@ -54,7 +54,7 @@ pack_dtype_p = core.Primitive("pack_dtype")
|
||||
|
||||
@pack_dtype_p.def_abstract_eval
|
||||
def pack_dtype_abstract_eval(*xs, dtype):
|
||||
if dtypes.issubdtype(dtype, FusableElementDType):
|
||||
if dtypes.issubdtype(dtype, fusibleElementDType):
|
||||
return dtype.abstract_pack(*xs)
|
||||
raise ValueError("Attempted to pack non-fusion dtype: {dtype}")
|
||||
|
||||
@ -69,7 +69,7 @@ unpack_dtype_p.multiple_results = True
|
||||
|
||||
@unpack_dtype_p.def_abstract_eval
|
||||
def unpack_dtype_abstract_eval(x):
|
||||
if dtypes.issubdtype(x.dtype, FusableElementDType):
|
||||
if dtypes.issubdtype(x.dtype, fusibleElementDType):
|
||||
return x.dtype.abstract_unpack(x)
|
||||
elif isinstance(x.dtype, pallas_core.AbstractMemoryRef):
|
||||
raise NotImplementedError()
|
||||
@ -80,20 +80,20 @@ def unpack(x):
|
||||
return unpack_dtype_p.bind(x)
|
||||
|
||||
|
||||
class FusableElementDType(dtypes.extended):
|
||||
"""Scalar dtype for fusable dtypes."""
|
||||
class fusibleElementDType(dtypes.extended):
|
||||
"""Scalar dtype for fusible dtypes."""
|
||||
|
||||
|
||||
class FusableTyRules:
|
||||
class fusibleTyRules:
|
||||
allow_conversion: bool = False
|
||||
|
||||
|
||||
class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
|
||||
"""Base class for fusable extended dtypes."""
|
||||
"""Base class for fusible extended dtypes."""
|
||||
|
||||
_op_registry = {}
|
||||
_rules = FusableTyRules
|
||||
type = FusableElementDType
|
||||
_rules = fusibleTyRules
|
||||
type = fusibleElementDType
|
||||
|
||||
@abc.abstractmethod
|
||||
def abstract_unpack(self, x) -> Sequence[Any]:
|
||||
@ -124,7 +124,7 @@ class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
def physicalize(f):
|
||||
"""Runs a function that contains fusable extended dtypes."""
|
||||
"""Runs a function that contains fusible extended dtypes."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if kwargs:
|
||||
@ -203,7 +203,7 @@ class Context:
|
||||
def physicalize_interp(
|
||||
jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
|
||||
):
|
||||
"""Physicalizes a jaxpr by replacing fusable dtypes with physical types."""
|
||||
"""Physicalizes a jaxpr by replacing fusible dtypes with physical types."""
|
||||
# TODO: Merge into JAX core.
|
||||
env: dict[core.Var, Any] = {}
|
||||
|
||||
@ -446,12 +446,12 @@ def _pack_dtype_pull_rule(
|
||||
return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error
|
||||
|
||||
|
||||
def _fusable_physicalize_rule(
|
||||
def _fusible_physicalize_rule(
|
||||
_, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
|
||||
):
|
||||
consts, _ = util.split_list(consts_and_args, [num_consts])
|
||||
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
|
||||
return fusable_p.bind(
|
||||
return fusible_p.bind(
|
||||
*consts_and_args,
|
||||
jaxpr=new_jaxpr.jaxpr,
|
||||
num_consts=num_consts,
|
||||
@ -461,4 +461,4 @@ def _fusable_physicalize_rule(
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[fusable_p] = _fusable_physicalize_rule
|
||||
_physicalize_rules[fusible_p] = _fusible_physicalize_rule
|
@ -23,22 +23,22 @@ from jax._src import core as jax_core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas.fuser import fusable_dtype
|
||||
from jax._src.pallas.fuser import fusible_dtype
|
||||
from jax._src.pallas.fuser import fusion as fusion_lib
|
||||
from jax._src.pallas.fuser.fusable import fusable_p
|
||||
from jax._src.pallas.fuser.fusible import fusible_p
|
||||
|
||||
|
||||
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
|
||||
"""Fuses a function into a single fusable.
|
||||
"""Fuses a function into a single fusible.
|
||||
|
||||
Args:
|
||||
f: The function to fuse.
|
||||
physicalize: (experimental) whether to physicalize the function.
|
||||
debug: Whether to print debug information.
|
||||
|
||||
There should be a single call to a `fusable` inside the body of `f`. `fuse`
|
||||
There should be a single call to a `fusible` inside the body of `f`. `fuse`
|
||||
returns a transformed function that will fuse the surrounding computation into
|
||||
the fusable and invoke it.
|
||||
the fusible and invoke it.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
@ -58,7 +58,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
if physicalize:
|
||||
wrapper = fusable_dtype.physicalize(wrapper)
|
||||
wrapper = fusible_dtype.physicalize(wrapper)
|
||||
return wrapper
|
||||
|
||||
if f is not None:
|
||||
@ -66,7 +66,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
|
||||
return decorator
|
||||
|
||||
|
||||
_fusable: dict[jax_core.Primitive, Any] = {}
|
||||
_fusible: dict[jax_core.Primitive, Any] = {}
|
||||
|
||||
|
||||
def _construct_fusion_jaxpr(
|
||||
@ -148,11 +148,11 @@ def _construct_output_fusions(
|
||||
jaxpr,
|
||||
out_tree,
|
||||
fusion_eqn_index,
|
||||
fusion_eqn_outvars, # Flat list of vars output by the fusable eqn
|
||||
fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs
|
||||
fusion_eqn_outvars, # Flat list of vars output by the fusible eqn
|
||||
fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs
|
||||
output_fusion_prefix, # Pytree defining output groups
|
||||
):
|
||||
# 1. Create jaxpr_out: represents computation *after* the fusable
|
||||
# 1. Create jaxpr_out: represents computation *after* the fusible
|
||||
# Inputs: fusion_eqn_outvars
|
||||
# Outputs: jaxpr.outvars
|
||||
jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
|
||||
@ -164,15 +164,15 @@ def _construct_output_fusions(
|
||||
tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs
|
||||
tree_util.tree_unflatten(
|
||||
fusion_eqn_out_tree, fusion_eqn_outvars
|
||||
), # Fusable outputs as inputs
|
||||
), # Fusible outputs as inputs
|
||||
)
|
||||
|
||||
# 2. Group fusable outputs based on the mask
|
||||
unflat_fusable_outvars = jax.tree.unflatten(
|
||||
# 2. Group fusible outputs based on the mask
|
||||
unflat_fusible_outvars = jax.tree.unflatten(
|
||||
fusion_eqn_out_tree, fusion_eqn_outvars
|
||||
)
|
||||
partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to(
|
||||
unflat_fusable_outvars
|
||||
unflat_fusible_outvars
|
||||
)
|
||||
|
||||
# 3. Calculate dependencies and check disjointness
|
||||
@ -180,10 +180,10 @@ def _construct_output_fusions(
|
||||
already_used_final_outputs = set() # Indices of final outputs already claimed
|
||||
for outvars_group in partial_flat:
|
||||
# Identify vars in this group
|
||||
used_fusable_outvars = set(jax.tree.leaves(outvars_group))
|
||||
used_fusible_outvars = set(jax.tree.leaves(outvars_group))
|
||||
# Create mask for jaxpr_out inputs corresponding to this group
|
||||
in_used_mask = [
|
||||
True if v in used_fusable_outvars else False for v in jaxpr_out.invars
|
||||
True if v in used_fusible_outvars else False for v in jaxpr_out.invars
|
||||
]
|
||||
# Trace dependencies through jaxpr_out to find which final outputs are affected
|
||||
downstream_used_mask = _find_downstream(
|
||||
@ -257,11 +257,11 @@ def fuse_jaxpr(
|
||||
|
||||
# Collect input fusions
|
||||
for i, eqn in enumerate(jaxpr.eqns):
|
||||
if eqn.primitive is fusable_p:
|
||||
if eqn.primitive is fusible_p:
|
||||
fusion_eqn_index = i
|
||||
break
|
||||
if fusion_eqn_index is None:
|
||||
raise ValueError("No fusable eqn found")
|
||||
raise ValueError("No fusible eqn found")
|
||||
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
|
||||
|
||||
# Now let's check if we need to do any fusion at all, e.g. do the outputs of
|
||||
@ -269,13 +269,13 @@ def fuse_jaxpr(
|
||||
# with all the inputs and outputs to check if there is a dependence.
|
||||
dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars),
|
||||
instantiate=True)
|
||||
if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns):
|
||||
if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns):
|
||||
# Short circuit if there is nothing to fuse.
|
||||
return jax_core.eval_jaxpr(dced_jaxpr, consts, *args)
|
||||
|
||||
candidate_values = [*consts, *args]
|
||||
|
||||
# Construct fusions for non-constant inputs to the fusable.
|
||||
# Construct fusions for non-constant inputs to the fusible.
|
||||
in_fusions_flat = [
|
||||
construct_fusion(
|
||||
candidate_values,
|
||||
|
@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
|
||||
from jax._src.pallas.fuser.fusable import fusable as fusable
|
||||
from jax._src.pallas.fuser.fusible import fusible as fusible
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
||||
|
@ -702,8 +702,8 @@ jax_multiplatform_test(
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_fusable_matmul_test",
|
||||
srcs = ["tpu_fusable_matmul_test.py"],
|
||||
name = "tpu_fusible_matmul_test",
|
||||
srcs = ["tpu_fusible_matmul_test.py"],
|
||||
disable_configs = [
|
||||
"tpu_v3",
|
||||
"tpu_pjrt_c_api",
|
||||
|
@ -28,7 +28,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
@fuser.fusable
|
||||
@fuser.fusible
|
||||
def f(x_fn, y_fn):
|
||||
x = x_fn()
|
||||
if y_fn is None:
|
||||
@ -40,7 +40,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_trivial(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -63,7 +63,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_should_error_if_not_disjoint(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -89,7 +89,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_allows_permute(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -112,7 +112,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_with_nesting(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -136,7 +136,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_with_nesting_and_permutation(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -160,7 +160,7 @@ class FusionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_separate_output_fusions_with_deep_output_mask(self):
|
||||
|
||||
@fuser.fusable(output_fusion_prefix=(True, (True, True)))
|
||||
@fuser.fusible(output_fusion_prefix=(True, (True, True)))
|
||||
def f(x_fn, y_fn, z_fn, o_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -185,7 +185,8 @@ class FusionTest(jtu.JaxTestCase):
|
||||
np.testing.assert_array_equal(z_out, z + z)
|
||||
|
||||
def test_separate_output_fusions_with_reused_value(self):
|
||||
@fuser.fusable(output_fusion_prefix=(True, True))
|
||||
|
||||
@fuser.fusible(output_fusion_prefix=(True, True))
|
||||
def f(x_fn, y_fn, z_fns):
|
||||
x = x_fn()
|
||||
y = y_fn()
|
||||
@ -209,7 +210,8 @@ class FusionTest(jtu.JaxTestCase):
|
||||
np.testing.assert_array_equal(y_out, y + a)
|
||||
|
||||
def test_empty_fusion(self):
|
||||
@fuser.fusable
|
||||
|
||||
@fuser.fusible
|
||||
def f(x_fn, y_fn):
|
||||
x = x_fn()
|
||||
if y_fn is None:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fusable matmul test."""
|
||||
"""Fusible matmul test."""
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
@ -75,7 +75,7 @@ def matmul_kernel(
|
||||
jax.tree.map(lambda ref, x: ref.set(x), o_ref, out)
|
||||
|
||||
|
||||
def _fusable_matmul(
|
||||
def _fusible_matmul(
|
||||
x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation
|
||||
y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation
|
||||
z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation
|
||||
@ -191,7 +191,7 @@ def _fusable_matmul(
|
||||
)[0]
|
||||
|
||||
|
||||
def fusable_matmul(
|
||||
def fusible_matmul(
|
||||
x: jax.Array,
|
||||
y: jax.Array,
|
||||
*,
|
||||
@ -201,9 +201,9 @@ def fusable_matmul(
|
||||
debug: bool = False,
|
||||
interpret: bool = False,
|
||||
) -> jax.Array:
|
||||
return fuser.fusable(
|
||||
return fuser.fusible(
|
||||
functools.partial(
|
||||
_fusable_matmul,
|
||||
_fusible_matmul,
|
||||
bm=bm,
|
||||
bk=bk,
|
||||
bn=bn,
|
||||
@ -213,7 +213,7 @@ def fusable_matmul(
|
||||
)(x, y)
|
||||
|
||||
|
||||
class FusableMatmulTest(jtu.JaxTestCase):
|
||||
class FusibleMatmulTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not jtu.is_device_tpu_at_least(4):
|
||||
@ -226,7 +226,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
x = jax.random.normal(k0, (512, 512), dtype)
|
||||
y = jax.random.normal(k1, (512, 512), dtype)
|
||||
np.testing.assert_allclose(
|
||||
jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5
|
||||
jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5
|
||||
)
|
||||
|
||||
@parameterized.parameters('float32', 'bfloat16')
|
||||
@ -238,7 +238,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_relu(x, y):
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
x = jnp.maximum(x, 0.0)
|
||||
return x
|
||||
|
||||
@ -258,7 +258,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_bias(x, y, b):
|
||||
x = fusable_matmul(x, y).astype(dtype) + b
|
||||
x = fusible_matmul(x, y).astype(dtype) + b
|
||||
x = jnp.maximum(x, 0.0)
|
||||
return x
|
||||
|
||||
@ -277,7 +277,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y):
|
||||
x = fusable_matmul(x, y[1])
|
||||
x = fusible_matmul(x, y[1])
|
||||
return x
|
||||
|
||||
np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5)
|
||||
@ -291,7 +291,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y, i):
|
||||
x = fusable_matmul(x, y[i])
|
||||
x = fusible_matmul(x, y[i])
|
||||
return x
|
||||
|
||||
np.testing.assert_allclose(
|
||||
@ -308,7 +308,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y, b, i, j):
|
||||
x = fusable_matmul(x, y[j]).astype(dtype) + b[i]
|
||||
x = fusible_matmul(x, y[j]).astype(dtype) + b[i]
|
||||
return x
|
||||
|
||||
np.testing.assert_allclose(
|
||||
@ -326,7 +326,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y):
|
||||
x = fusable_matmul(x, y[1, 1])
|
||||
x = fusible_matmul(x, y[1, 1])
|
||||
return x
|
||||
|
||||
np.testing.assert_allclose(
|
||||
@ -342,7 +342,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y):
|
||||
x = fusable_matmul(x, y[1][1])
|
||||
x = fusible_matmul(x, y[1][1])
|
||||
return x
|
||||
|
||||
np.testing.assert_allclose(
|
||||
@ -358,7 +358,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y, i, j):
|
||||
x = fusable_matmul(x, y[i][j])
|
||||
x = fusible_matmul(x, y[i][j])
|
||||
return x
|
||||
|
||||
for i in range(2):
|
||||
@ -376,7 +376,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y, i, j):
|
||||
x = fusable_matmul(x, y[2][i, j])
|
||||
x = fusible_matmul(x, y[2][i, j])
|
||||
return x
|
||||
|
||||
for i in range(2):
|
||||
@ -397,7 +397,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@fuser.fuse
|
||||
def matmul_slice(x, y, b, i, j, k):
|
||||
x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype)
|
||||
x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype)
|
||||
return x + b[i, j]
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -428,7 +428,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@fuser.fuse
|
||||
def matmul_concat(x, ys):
|
||||
y = jnp.concatenate(ys, axis=1)
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jax.jit
|
||||
@ -454,7 +454,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@fuser.fuse
|
||||
def matmul_concat(x, ys):
|
||||
y = jnp.concatenate(ys, axis=0)
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -482,7 +482,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
def matmul_concat(x, ys, y3):
|
||||
y = jnp.concatenate(ys, axis=0)
|
||||
y = jnp.concatenate([y, y3], axis=1)
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -509,7 +509,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@fuser.fuse
|
||||
def matmul_concat(x, y1, y2):
|
||||
y = jnp.concatenate([y1, y2[3]], axis=0)
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -534,7 +534,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@fuser.fuse
|
||||
def matmul_concat(x, y1, y2):
|
||||
y = jnp.concatenate([y1, y2[3]], axis=1)[1]
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -559,7 +559,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
@fuser.fuse
|
||||
def matmul_concat(x, y1, y2, i, j):
|
||||
y = jnp.concatenate([y1, y2[i]], axis=1)[j]
|
||||
x = fusable_matmul(x, y)
|
||||
x = fusible_matmul(x, y)
|
||||
return x
|
||||
|
||||
@jit_no_excess_precision
|
||||
@ -585,7 +585,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -607,7 +607,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -629,7 +629,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -651,7 +651,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -673,7 +673,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -695,7 +695,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -716,7 +716,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -738,7 +738,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bm=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bm=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -760,7 +760,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bm=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bm=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -782,7 +782,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z.T
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -804,7 +804,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
|
||||
return z.T * 2
|
||||
|
||||
impl = fuser.fuse(
|
||||
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
|
||||
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
|
||||
)
|
||||
ref = functools.partial(matmul, mm_ref)
|
||||
|
||||
@ -867,7 +867,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
impl = fuser.fuse(
|
||||
functools.partial(
|
||||
matmul,
|
||||
fusable_matmul,
|
||||
fusible_matmul,
|
||||
)
|
||||
)
|
||||
ref = functools.partial(matmul, dot_ref)
|
||||
@ -893,7 +893,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
|
||||
out_ref = jit_no_excess_precision(ref)(x, y)
|
||||
|
||||
impl = fuser.fuse(functools.partial(matmul, fusable_matmul))
|
||||
impl = fuser.fuse(functools.partial(matmul, fusible_matmul))
|
||||
out = jax.jit(impl)(x, y)
|
||||
|
||||
self.assertAllClose(out, out_ref, atol=0)
|
||||
@ -917,7 +917,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
impl = fuser.fuse(
|
||||
functools.partial(
|
||||
matmul,
|
||||
functools.partial(fusable_matmul, bk=256, bn=128),
|
||||
functools.partial(fusible_matmul, bk=256, bn=128),
|
||||
)
|
||||
)
|
||||
out = jax.jit(impl)(x, y)
|
||||
@ -953,7 +953,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
functools.partial(
|
||||
matmul,
|
||||
functools.partial(
|
||||
fusable_matmul,
|
||||
fusible_matmul,
|
||||
bm=bm,
|
||||
bk=bk,
|
||||
bn=bn,
|
||||
@ -990,7 +990,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
functools.partial(
|
||||
matmul,
|
||||
functools.partial(
|
||||
fusable_matmul,
|
||||
fusible_matmul,
|
||||
bm=bm,
|
||||
bk=bk,
|
||||
bn=bn,
|
||||
@ -1025,7 +1025,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
|
||||
functools.partial(
|
||||
matmul,
|
||||
functools.partial(
|
||||
fusable_matmul,
|
||||
fusible_matmul,
|
||||
bm=bm,
|
||||
bk=bk,
|
||||
bn=bn,
|
Loading…
x
Reference in New Issue
Block a user