[pallas] Fix spelling of 'fusible'.

PiperOrigin-RevId: 747663692
This commit is contained in:
Chris Jones 2025-04-14 19:35:09 -07:00 committed by jax authors
parent 0ed0fb7c54
commit 1926b99bfd
10 changed files with 108 additions and 106 deletions

View File

@ -721,7 +721,7 @@ pytype_strict_library(
":pallas", # build_cleaner: keep
"//jax/_src/pallas/fuser:block_spec",
"//jax/_src/pallas/fuser:custom_evaluate",
"//jax/_src/pallas/fuser:fusable",
"//jax/_src/pallas/fuser:fusible",
"//jax/_src/pallas/fuser:fusion",
"//jax/_src/pallas/fuser:jaxpr_fusion",
],

View File

@ -33,7 +33,7 @@ pytype_strict_library(
deps = [
":block_spec",
":custom_evaluate",
":fusable",
":fusible",
":fusion",
":jaxpr_fusion",
],
@ -58,9 +58,9 @@ pytype_strict_library(
)
pytype_strict_library(
name = "fusable",
name = "fusible",
srcs = [
"fusable.py",
"fusible.py",
],
deps = [
":fusion",
@ -91,8 +91,8 @@ pytype_strict_library(
"jaxpr_fusion.py",
],
deps = [
":fusable",
":fusable_dtype",
":fusible",
":fusible_dtype",
":fusion",
"//jax",
"//jax:api_util",
@ -104,13 +104,13 @@ pytype_strict_library(
)
pytype_strict_library(
name = "fusable_dtype",
name = "fusible_dtype",
srcs = [
"fusable_dtype.py",
"fusible_dtype.py",
],
deps = [
":block_spec",
":fusable",
":fusible",
"//jax",
"//jax:api_util",
"//jax:core",

View File

@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusible import fusible as fusible
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusable primitive."""
"""Fusible primitive."""
from typing import Any
import jax
@ -25,8 +25,8 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusion as fusion_lib
fusable_p = jax_core.Primitive('fusable')
fusable_p.multiple_results = True
fusible_p = jax_core.Primitive('fusible')
fusible_p.multiple_results = True
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
@ -37,7 +37,7 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
)
def fusable(f=None, *, output_fusion_prefix: Any = True):
def fusible(f=None, *, output_fusion_prefix: Any = True):
def decorator(f):
def wrapper(*args):
def wrapped(*args):
@ -45,14 +45,14 @@ def fusable(f=None, *, output_fusion_prefix: Any = True):
return f(*in_fusions, None)
flat_args, in_tree = tree_util.tree_flatten(args)
debug_info = api_util.debug_info('fusable', wrapped, args, {})
debug_info = api_util.debug_info('fusible', wrapped, args, {})
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
)
flat_avals = [jax_core.get_aval(x) for x in flat_args]
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
out = fusable_p.bind(
out = fusible_p.bind(
*consts,
*flat_args,
jaxpr=jaxpr,
@ -71,16 +71,16 @@ def fusable(f=None, *, output_fusion_prefix: Any = True):
return decorator
@fusable_p.def_impl
@fusible_p.def_impl
def _(*consts_and_args, jaxpr, num_consts, **_):
consts, args = util.split_list(consts_and_args, [num_consts])
return jax_core.eval_jaxpr(jaxpr, consts, *args)
mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl))
mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl))
@fusable_p.def_abstract_eval
@fusible_p.def_abstract_eval
def _(*args, jaxpr, **kwargs):
del args, kwargs
return [v.aval for v in jaxpr.outvars]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom fusable dtypes."""
"""Custom fusible dtypes."""
import abc
import dataclasses
@ -34,7 +34,7 @@ from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas.fuser import block_spec
from jax._src.pallas.fuser.fusable import fusable_p
from jax._src.pallas.fuser.fusible import fusible_p
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.util import foreach
@ -54,7 +54,7 @@ pack_dtype_p = core.Primitive("pack_dtype")
@pack_dtype_p.def_abstract_eval
def pack_dtype_abstract_eval(*xs, dtype):
if dtypes.issubdtype(dtype, FusableElementDType):
if dtypes.issubdtype(dtype, fusibleElementDType):
return dtype.abstract_pack(*xs)
raise ValueError("Attempted to pack non-fusion dtype: {dtype}")
@ -69,7 +69,7 @@ unpack_dtype_p.multiple_results = True
@unpack_dtype_p.def_abstract_eval
def unpack_dtype_abstract_eval(x):
if dtypes.issubdtype(x.dtype, FusableElementDType):
if dtypes.issubdtype(x.dtype, fusibleElementDType):
return x.dtype.abstract_unpack(x)
elif isinstance(x.dtype, pallas_core.AbstractMemoryRef):
raise NotImplementedError()
@ -80,20 +80,20 @@ def unpack(x):
return unpack_dtype_p.bind(x)
class FusableElementDType(dtypes.extended):
"""Scalar dtype for fusable dtypes."""
class fusibleElementDType(dtypes.extended):
"""Scalar dtype for fusible dtypes."""
class FusableTyRules:
class fusibleTyRules:
allow_conversion: bool = False
class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
"""Base class for fusable extended dtypes."""
"""Base class for fusible extended dtypes."""
_op_registry = {}
_rules = FusableTyRules
type = FusableElementDType
_rules = fusibleTyRules
type = fusibleElementDType
@abc.abstractmethod
def abstract_unpack(self, x) -> Sequence[Any]:
@ -124,7 +124,7 @@ class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
def physicalize(f):
"""Runs a function that contains fusable extended dtypes."""
"""Runs a function that contains fusible extended dtypes."""
def wrapper(*args, **kwargs):
if kwargs:
@ -203,7 +203,7 @@ class Context:
def physicalize_interp(
jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
):
"""Physicalizes a jaxpr by replacing fusable dtypes with physical types."""
"""Physicalizes a jaxpr by replacing fusible dtypes with physical types."""
# TODO: Merge into JAX core.
env: dict[core.Var, Any] = {}
@ -446,12 +446,12 @@ def _pack_dtype_pull_rule(
return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error
def _fusable_physicalize_rule(
def _fusible_physicalize_rule(
_, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
):
consts, _ = util.split_list(consts_and_args, [num_consts])
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
return fusable_p.bind(
return fusible_p.bind(
*consts_and_args,
jaxpr=new_jaxpr.jaxpr,
num_consts=num_consts,
@ -461,4 +461,4 @@ def _fusable_physicalize_rule(
)
_physicalize_rules[fusable_p] = _fusable_physicalize_rule
_physicalize_rules[fusible_p] = _fusible_physicalize_rule

View File

@ -23,22 +23,22 @@ from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusable_dtype
from jax._src.pallas.fuser import fusible_dtype
from jax._src.pallas.fuser import fusion as fusion_lib
from jax._src.pallas.fuser.fusable import fusable_p
from jax._src.pallas.fuser.fusible import fusible_p
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
"""Fuses a function into a single fusable.
"""Fuses a function into a single fusible.
Args:
f: The function to fuse.
physicalize: (experimental) whether to physicalize the function.
debug: Whether to print debug information.
There should be a single call to a `fusable` inside the body of `f`. `fuse`
There should be a single call to a `fusible` inside the body of `f`. `fuse`
returns a transformed function that will fuse the surrounding computation into
the fusable and invoke it.
the fusible and invoke it.
"""
def decorator(f):
@ -58,7 +58,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
return tree_util.tree_unflatten(out_tree, out_flat)
if physicalize:
wrapper = fusable_dtype.physicalize(wrapper)
wrapper = fusible_dtype.physicalize(wrapper)
return wrapper
if f is not None:
@ -66,7 +66,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
return decorator
_fusable: dict[jax_core.Primitive, Any] = {}
_fusible: dict[jax_core.Primitive, Any] = {}
def _construct_fusion_jaxpr(
@ -148,11 +148,11 @@ def _construct_output_fusions(
jaxpr,
out_tree,
fusion_eqn_index,
fusion_eqn_outvars, # Flat list of vars output by the fusable eqn
fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs
fusion_eqn_outvars, # Flat list of vars output by the fusible eqn
fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs
output_fusion_prefix, # Pytree defining output groups
):
# 1. Create jaxpr_out: represents computation *after* the fusable
# 1. Create jaxpr_out: represents computation *after* the fusible
# Inputs: fusion_eqn_outvars
# Outputs: jaxpr.outvars
jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
@ -164,15 +164,15 @@ def _construct_output_fusions(
tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs
tree_util.tree_unflatten(
fusion_eqn_out_tree, fusion_eqn_outvars
), # Fusable outputs as inputs
), # Fusible outputs as inputs
)
# 2. Group fusable outputs based on the mask
unflat_fusable_outvars = jax.tree.unflatten(
# 2. Group fusible outputs based on the mask
unflat_fusible_outvars = jax.tree.unflatten(
fusion_eqn_out_tree, fusion_eqn_outvars
)
partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to(
unflat_fusable_outvars
unflat_fusible_outvars
)
# 3. Calculate dependencies and check disjointness
@ -180,10 +180,10 @@ def _construct_output_fusions(
already_used_final_outputs = set() # Indices of final outputs already claimed
for outvars_group in partial_flat:
# Identify vars in this group
used_fusable_outvars = set(jax.tree.leaves(outvars_group))
used_fusible_outvars = set(jax.tree.leaves(outvars_group))
# Create mask for jaxpr_out inputs corresponding to this group
in_used_mask = [
True if v in used_fusable_outvars else False for v in jaxpr_out.invars
True if v in used_fusible_outvars else False for v in jaxpr_out.invars
]
# Trace dependencies through jaxpr_out to find which final outputs are affected
downstream_used_mask = _find_downstream(
@ -257,11 +257,11 @@ def fuse_jaxpr(
# Collect input fusions
for i, eqn in enumerate(jaxpr.eqns):
if eqn.primitive is fusable_p:
if eqn.primitive is fusible_p:
fusion_eqn_index = i
break
if fusion_eqn_index is None:
raise ValueError("No fusable eqn found")
raise ValueError("No fusible eqn found")
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
# Now let's check if we need to do any fusion at all, e.g. do the outputs of
@ -269,13 +269,13 @@ def fuse_jaxpr(
# with all the inputs and outputs to check if there is a dependence.
dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars),
instantiate=True)
if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns):
if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns):
# Short circuit if there is nothing to fuse.
return jax_core.eval_jaxpr(dced_jaxpr, consts, *args)
candidate_values = [*consts, *args]
# Construct fusions for non-constant inputs to the fusable.
# Construct fusions for non-constant inputs to the fusible.
in_fusions_flat = [
construct_fusion(
candidate_values,

View File

@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusible import fusible as fusible
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

View File

@ -702,8 +702,8 @@ jax_multiplatform_test(
)
jax_multiplatform_test(
name = "tpu_fusable_matmul_test",
srcs = ["tpu_fusable_matmul_test.py"],
name = "tpu_fusible_matmul_test",
srcs = ["tpu_fusible_matmul_test.py"],
disable_configs = [
"tpu_v3",
"tpu_pjrt_c_api",

View File

@ -28,7 +28,7 @@ class FusionTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
@fuser.fusable
@fuser.fusible
def f(x_fn, y_fn):
x = x_fn()
if y_fn is None:
@ -40,7 +40,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_trivial(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -63,7 +63,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_should_error_if_not_disjoint(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -89,7 +89,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_allows_permute(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -112,7 +112,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_with_nesting(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -136,7 +136,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_with_nesting_and_permutation(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -160,7 +160,7 @@ class FusionTest(jtu.JaxTestCase):
def test_separate_output_fusions_with_deep_output_mask(self):
@fuser.fusable(output_fusion_prefix=(True, (True, True)))
@fuser.fusible(output_fusion_prefix=(True, (True, True)))
def f(x_fn, y_fn, z_fn, o_fns):
x = x_fn()
y = y_fn()
@ -185,7 +185,8 @@ class FusionTest(jtu.JaxTestCase):
np.testing.assert_array_equal(z_out, z + z)
def test_separate_output_fusions_with_reused_value(self):
@fuser.fusable(output_fusion_prefix=(True, True))
@fuser.fusible(output_fusion_prefix=(True, True))
def f(x_fn, y_fn, z_fns):
x = x_fn()
y = y_fn()
@ -209,7 +210,8 @@ class FusionTest(jtu.JaxTestCase):
np.testing.assert_array_equal(y_out, y + a)
def test_empty_fusion(self):
@fuser.fusable
@fuser.fusible
def f(x_fn, y_fn):
x = x_fn()
if y_fn is None:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusable matmul test."""
"""Fusible matmul test."""
import functools
from typing import Any
@ -75,7 +75,7 @@ def matmul_kernel(
jax.tree.map(lambda ref, x: ref.set(x), o_ref, out)
def _fusable_matmul(
def _fusible_matmul(
x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation
y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation
z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation
@ -191,7 +191,7 @@ def _fusable_matmul(
)[0]
def fusable_matmul(
def fusible_matmul(
x: jax.Array,
y: jax.Array,
*,
@ -201,9 +201,9 @@ def fusable_matmul(
debug: bool = False,
interpret: bool = False,
) -> jax.Array:
return fuser.fusable(
return fuser.fusible(
functools.partial(
_fusable_matmul,
_fusible_matmul,
bm=bm,
bk=bk,
bn=bn,
@ -213,7 +213,7 @@ def fusable_matmul(
)(x, y)
class FusableMatmulTest(jtu.JaxTestCase):
class FusibleMatmulTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.is_device_tpu_at_least(4):
@ -226,7 +226,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
x = jax.random.normal(k0, (512, 512), dtype)
y = jax.random.normal(k1, (512, 512), dtype)
np.testing.assert_allclose(
jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5
jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5
)
@parameterized.parameters('float32', 'bfloat16')
@ -238,7 +238,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_relu(x, y):
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
x = jnp.maximum(x, 0.0)
return x
@ -258,7 +258,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_bias(x, y, b):
x = fusable_matmul(x, y).astype(dtype) + b
x = fusible_matmul(x, y).astype(dtype) + b
x = jnp.maximum(x, 0.0)
return x
@ -277,7 +277,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y):
x = fusable_matmul(x, y[1])
x = fusible_matmul(x, y[1])
return x
np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5)
@ -291,7 +291,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y, i):
x = fusable_matmul(x, y[i])
x = fusible_matmul(x, y[i])
return x
np.testing.assert_allclose(
@ -308,7 +308,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y, b, i, j):
x = fusable_matmul(x, y[j]).astype(dtype) + b[i]
x = fusible_matmul(x, y[j]).astype(dtype) + b[i]
return x
np.testing.assert_allclose(
@ -326,7 +326,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y):
x = fusable_matmul(x, y[1, 1])
x = fusible_matmul(x, y[1, 1])
return x
np.testing.assert_allclose(
@ -342,7 +342,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y):
x = fusable_matmul(x, y[1][1])
x = fusible_matmul(x, y[1][1])
return x
np.testing.assert_allclose(
@ -358,7 +358,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y, i, j):
x = fusable_matmul(x, y[i][j])
x = fusible_matmul(x, y[i][j])
return x
for i in range(2):
@ -376,7 +376,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y, i, j):
x = fusable_matmul(x, y[2][i, j])
x = fusible_matmul(x, y[2][i, j])
return x
for i in range(2):
@ -397,7 +397,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@jax.jit
@fuser.fuse
def matmul_slice(x, y, b, i, j, k):
x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype)
x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype)
return x + b[i, j]
@jit_no_excess_precision
@ -428,7 +428,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@fuser.fuse
def matmul_concat(x, ys):
y = jnp.concatenate(ys, axis=1)
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jax.jit
@ -454,7 +454,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@fuser.fuse
def matmul_concat(x, ys):
y = jnp.concatenate(ys, axis=0)
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jit_no_excess_precision
@ -482,7 +482,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
def matmul_concat(x, ys, y3):
y = jnp.concatenate(ys, axis=0)
y = jnp.concatenate([y, y3], axis=1)
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jit_no_excess_precision
@ -509,7 +509,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@fuser.fuse
def matmul_concat(x, y1, y2):
y = jnp.concatenate([y1, y2[3]], axis=0)
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jit_no_excess_precision
@ -534,7 +534,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@fuser.fuse
def matmul_concat(x, y1, y2):
y = jnp.concatenate([y1, y2[3]], axis=1)[1]
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jit_no_excess_precision
@ -559,7 +559,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
@fuser.fuse
def matmul_concat(x, y1, y2, i, j):
y = jnp.concatenate([y1, y2[i]], axis=1)[j]
x = fusable_matmul(x, y)
x = fusible_matmul(x, y)
return x
@jit_no_excess_precision
@ -585,7 +585,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -607,7 +607,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -629,7 +629,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -651,7 +651,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -673,7 +673,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -695,7 +695,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -716,7 +716,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -738,7 +738,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bm=256))
functools.partial(matmul, functools.partial(fusible_matmul, bm=256))
)
ref = functools.partial(matmul, mm_ref)
@ -760,7 +760,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bm=256))
functools.partial(matmul, functools.partial(fusible_matmul, bm=256))
)
ref = functools.partial(matmul, mm_ref)
@ -782,7 +782,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z.T
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -804,7 +804,7 @@ class FusableMatmulTest(jtu.JaxTestCase):
return z.T * 2
impl = fuser.fuse(
functools.partial(matmul, functools.partial(fusable_matmul, bn=256))
functools.partial(matmul, functools.partial(fusible_matmul, bn=256))
)
ref = functools.partial(matmul, mm_ref)
@ -867,7 +867,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
impl = fuser.fuse(
functools.partial(
matmul,
fusable_matmul,
fusible_matmul,
)
)
ref = functools.partial(matmul, dot_ref)
@ -893,7 +893,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
out_ref = jit_no_excess_precision(ref)(x, y)
impl = fuser.fuse(functools.partial(matmul, fusable_matmul))
impl = fuser.fuse(functools.partial(matmul, fusible_matmul))
out = jax.jit(impl)(x, y)
self.assertAllClose(out, out_ref, atol=0)
@ -917,7 +917,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
impl = fuser.fuse(
functools.partial(
matmul,
functools.partial(fusable_matmul, bk=256, bn=128),
functools.partial(fusible_matmul, bk=256, bn=128),
)
)
out = jax.jit(impl)(x, y)
@ -953,7 +953,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
functools.partial(
matmul,
functools.partial(
fusable_matmul,
fusible_matmul,
bm=bm,
bk=bk,
bn=bn,
@ -990,7 +990,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
functools.partial(
matmul,
functools.partial(
fusable_matmul,
fusible_matmul,
bm=bm,
bk=bk,
bn=bn,
@ -1025,7 +1025,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase):
functools.partial(
matmul,
functools.partial(
fusable_matmul,
fusible_matmul,
bm=bm,
bk=bk,
bn=bn,