From 1926b99bfd3524f8c62521db454f69dcb942a1e5 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 14 Apr 2025 19:35:09 -0700 Subject: [PATCH] [pallas] Fix spelling of 'fusible'. PiperOrigin-RevId: 747663692 --- jax/BUILD | 2 +- jax/_src/pallas/fuser/BUILD | 16 ++-- jax/_src/pallas/fuser/__init__.py | 2 +- .../pallas/fuser/{fusable.py => fusible.py} | 18 ++--- .../{fusable_dtype.py => fusible_dtype.py} | 30 +++---- jax/_src/pallas/fuser/jaxpr_fusion.py | 40 +++++----- jax/experimental/pallas/fuser.py | 2 +- tests/pallas/BUILD | 4 +- tests/pallas/fusion_test.py | 20 ++--- ...mul_test.py => tpu_fusible_matmul_test.py} | 80 +++++++++---------- 10 files changed, 108 insertions(+), 106 deletions(-) rename jax/_src/pallas/fuser/{fusable.py => fusible.py} (86%) rename jax/_src/pallas/fuser/{fusable_dtype.py => fusible_dtype.py} (95%) rename tests/pallas/{tpu_fusable_matmul_test.py => tpu_fusible_matmul_test.py} (94%) diff --git a/jax/BUILD b/jax/BUILD index cb9a39efb..ca0fb0268 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -721,7 +721,7 @@ pytype_strict_library( ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", "//jax/_src/pallas/fuser:custom_evaluate", - "//jax/_src/pallas/fuser:fusable", + "//jax/_src/pallas/fuser:fusible", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", ], diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 8339ad670..a62a9937d 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -33,7 +33,7 @@ pytype_strict_library( deps = [ ":block_spec", ":custom_evaluate", - ":fusable", + ":fusible", ":fusion", ":jaxpr_fusion", ], @@ -58,9 +58,9 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable", + name = "fusible", srcs = [ - "fusable.py", + "fusible.py", ], deps = [ ":fusion", @@ -91,8 +91,8 @@ pytype_strict_library( "jaxpr_fusion.py", ], deps = [ - ":fusable", - ":fusable_dtype", + ":fusible", + ":fusible_dtype", ":fusion", "//jax", "//jax:api_util", @@ -104,13 +104,13 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable_dtype", + name = "fusible_dtype", srcs = [ - "fusable_dtype.py", + "fusible_dtype.py", ], deps = [ ":block_spec", - ":fusable", + ":fusible", "//jax", "//jax:api_util", "//jax:core", diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index 3295c8f10..39720100e 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusible.py similarity index 86% rename from jax/_src/pallas/fuser/fusable.py rename to jax/_src/pallas/fuser/fusible.py index d9d0ee0b4..289a9dc26 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusible.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable primitive.""" +"""Fusible primitive.""" from typing import Any import jax @@ -25,8 +25,8 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas.fuser import fusion as fusion_lib -fusable_p = jax_core.Primitive('fusable') -fusable_p.multiple_results = True +fusible_p = jax_core.Primitive('fusible') +fusible_p.multiple_results = True def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: @@ -37,7 +37,7 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: ) -def fusable(f=None, *, output_fusion_prefix: Any = True): +def fusible(f=None, *, output_fusion_prefix: Any = True): def decorator(f): def wrapper(*args): def wrapped(*args): @@ -45,14 +45,14 @@ def fusable(f=None, *, output_fusion_prefix: Any = True): return f(*in_fusions, None) flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) + debug_info = api_util.debug_info('fusible', wrapped, args, {}) flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(wrapped, debug_info=debug_info), in_tree ) flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) out_tree = out_tree_thunk() - out = fusable_p.bind( + out = fusible_p.bind( *consts, *flat_args, jaxpr=jaxpr, @@ -71,16 +71,16 @@ def fusable(f=None, *, output_fusion_prefix: Any = True): return decorator -@fusable_p.def_impl +@fusible_p.def_impl def _(*consts_and_args, jaxpr, num_consts, **_): consts, args = util.split_list(consts_and_args, [num_consts]) return jax_core.eval_jaxpr(jaxpr, consts, *args) -mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl)) +mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) -@fusable_p.def_abstract_eval +@fusible_p.def_abstract_eval def _(*args, jaxpr, **kwargs): del args, kwargs return [v.aval for v in jaxpr.outvars] diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py similarity index 95% rename from jax/_src/pallas/fuser/fusable_dtype.py rename to jax/_src/pallas/fuser/fusible_dtype.py index 99c80e652..8e6cfefcc 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Custom fusable dtypes.""" +"""Custom fusible dtypes.""" import abc import dataclasses @@ -34,7 +34,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.fuser import block_spec -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives from jax._src.util import foreach @@ -54,7 +54,7 @@ pack_dtype_p = core.Primitive("pack_dtype") @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, FusableElementDType): + if dtypes.issubdtype(dtype, fusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,7 +69,7 @@ unpack_dtype_p.multiple_results = True @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, FusableElementDType): + if dtypes.issubdtype(x.dtype, fusibleElementDType): return x.dtype.abstract_unpack(x) elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): raise NotImplementedError() @@ -80,20 +80,20 @@ def unpack(x): return unpack_dtype_p.bind(x) -class FusableElementDType(dtypes.extended): - """Scalar dtype for fusable dtypes.""" +class fusibleElementDType(dtypes.extended): + """Scalar dtype for fusible dtypes.""" -class FusableTyRules: +class fusibleTyRules: allow_conversion: bool = False class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): - """Base class for fusable extended dtypes.""" + """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = FusableTyRules - type = FusableElementDType + _rules = fusibleTyRules + type = fusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: @@ -124,7 +124,7 @@ class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): def physicalize(f): - """Runs a function that contains fusable extended dtypes.""" + """Runs a function that contains fusible extended dtypes.""" def wrapper(*args, **kwargs): if kwargs: @@ -203,7 +203,7 @@ class Context: def physicalize_interp( jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value ): - """Physicalizes a jaxpr by replacing fusable dtypes with physical types.""" + """Physicalizes a jaxpr by replacing fusible dtypes with physical types.""" # TODO: Merge into JAX core. env: dict[core.Var, Any] = {} @@ -446,12 +446,12 @@ def _pack_dtype_pull_rule( return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error -def _fusable_physicalize_rule( +def _fusible_physicalize_rule( _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func ): consts, _ = util.split_list(consts_and_args, [num_consts]) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) - return fusable_p.bind( + return fusible_p.bind( *consts_and_args, jaxpr=new_jaxpr.jaxpr, num_consts=num_consts, @@ -461,4 +461,4 @@ def _fusable_physicalize_rule( ) -_physicalize_rules[fusable_p] = _fusable_physicalize_rule +_physicalize_rules[fusible_p] = _fusible_physicalize_rule diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3c3c2a3d7..95768d71f 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -23,22 +23,22 @@ from jax._src import core as jax_core from jax._src import linear_util as lu from jax._src import tree_util from jax._src.interpreters import partial_eval as pe -from jax._src.pallas.fuser import fusable_dtype +from jax._src.pallas.fuser import fusible_dtype from jax._src.pallas.fuser import fusion as fusion_lib -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p def fuse(f=None, *, physicalize: bool = False, debug: bool = False): - """Fuses a function into a single fusable. + """Fuses a function into a single fusible. Args: f: The function to fuse. physicalize: (experimental) whether to physicalize the function. debug: Whether to print debug information. - There should be a single call to a `fusable` inside the body of `f`. `fuse` + There should be a single call to a `fusible` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into - the fusable and invoke it. + the fusible and invoke it. """ def decorator(f): @@ -58,7 +58,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False): return tree_util.tree_unflatten(out_tree, out_flat) if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) + wrapper = fusible_dtype.physicalize(wrapper) return wrapper if f is not None: @@ -66,7 +66,7 @@ def fuse(f=None, *, physicalize: bool = False, debug: bool = False): return decorator -_fusable: dict[jax_core.Primitive, Any] = {} +_fusible: dict[jax_core.Primitive, Any] = {} def _construct_fusion_jaxpr( @@ -148,11 +148,11 @@ def _construct_output_fusions( jaxpr, out_tree, fusion_eqn_index, - fusion_eqn_outvars, # Flat list of vars output by the fusable eqn - fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs + fusion_eqn_outvars, # Flat list of vars output by the fusible eqn + fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs output_fusion_prefix, # Pytree defining output groups ): - # 1. Create jaxpr_out: represents computation *after* the fusable + # 1. Create jaxpr_out: represents computation *after* the fusible # Inputs: fusion_eqn_outvars # Outputs: jaxpr.outvars jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( @@ -164,15 +164,15 @@ def _construct_output_fusions( tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs tree_util.tree_unflatten( fusion_eqn_out_tree, fusion_eqn_outvars - ), # Fusable outputs as inputs + ), # Fusible outputs as inputs ) - # 2. Group fusable outputs based on the mask - unflat_fusable_outvars = jax.tree.unflatten( + # 2. Group fusible outputs based on the mask + unflat_fusible_outvars = jax.tree.unflatten( fusion_eqn_out_tree, fusion_eqn_outvars ) partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( - unflat_fusable_outvars + unflat_fusible_outvars ) # 3. Calculate dependencies and check disjointness @@ -180,10 +180,10 @@ def _construct_output_fusions( already_used_final_outputs = set() # Indices of final outputs already claimed for outvars_group in partial_flat: # Identify vars in this group - used_fusable_outvars = set(jax.tree.leaves(outvars_group)) + used_fusible_outvars = set(jax.tree.leaves(outvars_group)) # Create mask for jaxpr_out inputs corresponding to this group in_used_mask = [ - True if v in used_fusable_outvars else False for v in jaxpr_out.invars + True if v in used_fusible_outvars else False for v in jaxpr_out.invars ] # Trace dependencies through jaxpr_out to find which final outputs are affected downstream_used_mask = _find_downstream( @@ -257,11 +257,11 @@ def fuse_jaxpr( # Collect input fusions for i, eqn in enumerate(jaxpr.eqns): - if eqn.primitive is fusable_p: + if eqn.primitive is fusible_p: fusion_eqn_index = i break if fusion_eqn_index is None: - raise ValueError("No fusable eqn found") + raise ValueError("No fusible eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] # Now let's check if we need to do any fusion at all, e.g. do the outputs of @@ -269,13 +269,13 @@ def fuse_jaxpr( # with all the inputs and outputs to check if there is a dependence. dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) - if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns): + if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns): # Short circuit if there is nothing to fuse. return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) candidate_values = [*consts, *args] - # Construct fusions for non-constant inputs to the fusable. + # Construct fusions for non-constant inputs to the fusible. in_fusions_flat = [ construct_fusion( candidate_values, diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 729a447b7..d4ec7e89c 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as mak from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 730c354d8..fa98c0af4 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -702,8 +702,8 @@ jax_multiplatform_test( ) jax_multiplatform_test( - name = "tpu_fusable_matmul_test", - srcs = ["tpu_fusable_matmul_test.py"], + name = "tpu_fusible_matmul_test", + srcs = ["tpu_fusible_matmul_test.py"], disable_configs = [ "tpu_v3", "tpu_pjrt_c_api", diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py index 2edcf78f1..4bd02345c 100644 --- a/tests/pallas/fusion_test.py +++ b/tests/pallas/fusion_test.py @@ -28,7 +28,7 @@ class FusionTest(jtu.JaxTestCase): @jax.jit @fuser.fuse - @fuser.fusable + @fuser.fusible def f(x_fn, y_fn): x = x_fn() if y_fn is None: @@ -40,7 +40,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_trivial(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -63,7 +63,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_should_error_if_not_disjoint(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -89,7 +89,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_allows_permute(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -112,7 +112,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_with_nesting(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -136,7 +136,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_with_nesting_and_permutation(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -160,7 +160,7 @@ class FusionTest(jtu.JaxTestCase): def test_separate_output_fusions_with_deep_output_mask(self): - @fuser.fusable(output_fusion_prefix=(True, (True, True))) + @fuser.fusible(output_fusion_prefix=(True, (True, True))) def f(x_fn, y_fn, z_fn, o_fns): x = x_fn() y = y_fn() @@ -185,7 +185,8 @@ class FusionTest(jtu.JaxTestCase): np.testing.assert_array_equal(z_out, z + z) def test_separate_output_fusions_with_reused_value(self): - @fuser.fusable(output_fusion_prefix=(True, True)) + + @fuser.fusible(output_fusion_prefix=(True, True)) def f(x_fn, y_fn, z_fns): x = x_fn() y = y_fn() @@ -209,7 +210,8 @@ class FusionTest(jtu.JaxTestCase): np.testing.assert_array_equal(y_out, y + a) def test_empty_fusion(self): - @fuser.fusable + + @fuser.fusible def f(x_fn, y_fn): x = x_fn() if y_fn is None: diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py similarity index 94% rename from tests/pallas/tpu_fusable_matmul_test.py rename to tests/pallas/tpu_fusible_matmul_test.py index 93523b174..2382c09f2 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable matmul test.""" +"""Fusible matmul test.""" import functools from typing import Any @@ -75,7 +75,7 @@ def matmul_kernel( jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) -def _fusable_matmul( +def _fusible_matmul( x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation @@ -191,7 +191,7 @@ def _fusable_matmul( )[0] -def fusable_matmul( +def fusible_matmul( x: jax.Array, y: jax.Array, *, @@ -201,9 +201,9 @@ def fusable_matmul( debug: bool = False, interpret: bool = False, ) -> jax.Array: - return fuser.fusable( + return fuser.fusible( functools.partial( - _fusable_matmul, + _fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -213,7 +213,7 @@ def fusable_matmul( )(x, y) -class FusableMatmulTest(jtu.JaxTestCase): +class FusibleMatmulTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_device_tpu_at_least(4): @@ -226,7 +226,7 @@ class FusableMatmulTest(jtu.JaxTestCase): x = jax.random.normal(k0, (512, 512), dtype) y = jax.random.normal(k1, (512, 512), dtype) np.testing.assert_allclose( - jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5 + jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -238,7 +238,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_relu(x, y): - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) x = jnp.maximum(x, 0.0) return x @@ -258,7 +258,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_bias(x, y, b): - x = fusable_matmul(x, y).astype(dtype) + b + x = fusible_matmul(x, y).astype(dtype) + b x = jnp.maximum(x, 0.0) return x @@ -277,7 +277,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1]) + x = fusible_matmul(x, y[1]) return x np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5) @@ -291,7 +291,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y, i): - x = fusable_matmul(x, y[i]) + x = fusible_matmul(x, y[i]) return x np.testing.assert_allclose( @@ -308,7 +308,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j): - x = fusable_matmul(x, y[j]).astype(dtype) + b[i] + x = fusible_matmul(x, y[j]).astype(dtype) + b[i] return x np.testing.assert_allclose( @@ -326,7 +326,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1, 1]) + x = fusible_matmul(x, y[1, 1]) return x np.testing.assert_allclose( @@ -342,7 +342,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1][1]) + x = fusible_matmul(x, y[1][1]) return x np.testing.assert_allclose( @@ -358,7 +358,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[i][j]) + x = fusible_matmul(x, y[i][j]) return x for i in range(2): @@ -376,7 +376,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[2][i, j]) + x = fusible_matmul(x, y[2][i, j]) return x for i in range(2): @@ -397,7 +397,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j, k): - x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype) + x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype) return x + b[i, j] @jit_no_excess_precision @@ -428,7 +428,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jax.jit @@ -454,7 +454,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -482,7 +482,7 @@ class FusableMatmulTest(jtu.JaxTestCase): def matmul_concat(x, ys, y3): y = jnp.concatenate(ys, axis=0) y = jnp.concatenate([y, y3], axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -509,7 +509,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -534,7 +534,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=1)[1] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -559,7 +559,7 @@ class FusableMatmulTest(jtu.JaxTestCase): @fuser.fuse def matmul_concat(x, y1, y2, i, j): y = jnp.concatenate([y1, y2[i]], axis=1)[j] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -585,7 +585,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -607,7 +607,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -629,7 +629,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -651,7 +651,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -673,7 +673,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -695,7 +695,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -716,7 +716,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -738,7 +738,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -760,7 +760,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -782,7 +782,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z.T impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -804,7 +804,7 @@ class FusableMatmulTest(jtu.JaxTestCase): return z.T * 2 impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -867,7 +867,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): impl = fuser.fuse( functools.partial( matmul, - fusable_matmul, + fusible_matmul, ) ) ref = functools.partial(matmul, dot_ref) @@ -893,7 +893,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): out_ref = jit_no_excess_precision(ref)(x, y) - impl = fuser.fuse(functools.partial(matmul, fusable_matmul)) + impl = fuser.fuse(functools.partial(matmul, fusible_matmul)) out = jax.jit(impl)(x, y) self.assertAllClose(out, out_ref, atol=0) @@ -917,7 +917,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): impl = fuser.fuse( functools.partial( matmul, - functools.partial(fusable_matmul, bk=256, bn=128), + functools.partial(fusible_matmul, bk=256, bn=128), ) ) out = jax.jit(impl)(x, y) @@ -953,7 +953,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -990,7 +990,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -1025,7 +1025,7 @@ class ExcessPrecisionTest(jtu.JaxTestCase): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn,