Add fuser to jax.experimental.pallas

Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
This commit is contained in:
Sharad Vikram 2025-03-03 17:25:59 -08:00 committed by jax authors
parent 0b6c355083
commit d32e282ff9
4 changed files with 48 additions and 1 deletions

View File

@ -29,6 +29,7 @@ load(
"jax_visibility",
"mosaic_gpu_internal_users",
"mosaic_internal_users",
"pallas_fuser_users",
"pallas_gpu_internal_users",
"pallas_tpu_internal_users",
"py_deps",
@ -105,6 +106,12 @@ package_group(
packages = pallas_tpu_internal_users,
)
package_group(
name = "pallas_fuser_users",
includes = [":internal"],
packages = pallas_fuser_users,
)
package_group(
name = "mosaic_gpu_users",
includes = [":internal"],
@ -628,6 +635,7 @@ pytype_strict_library(
"experimental/pallas/ops/gpu/**/*.py",
"experimental/pallas/ops/tpu/**/*.py",
"experimental/pallas/tpu.py",
"experimental/pallas/fuser.py",
"experimental/pallas/triton.py",
],
),
@ -664,6 +672,21 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "pallas_fuser",
srcs = ["experimental/pallas/fuser.py"],
visibility = [
":pallas_fuser_users",
],
deps = [
":pallas", # build_cleaner: keep
"//jax/_src/pallas/fuser:block_spec",
"//jax/_src/pallas/fuser:fusable",
"//jax/_src/pallas/fuser:fusion",
"//jax/_src/pallas/fuser:jaxpr_fusion",
],
)
pytype_strict_library(
name = "pallas_gpu_ops",
srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"],

View File

@ -35,7 +35,7 @@ class Fusion(Generic[A, K]):
in_type: tuple[tuple[Any, ...], dict[str, Any]]
out_type: Any
def __call__(self, *args: A.args, **kwargs: A.kwargs):
def __call__(self, *args: A.args, **kwargs: A.kwargs) -> K:
return self.func(*args, **kwargs)
@property

View File

@ -0,0 +1,23 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for fuser."""
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

View File

@ -46,6 +46,7 @@ mosaic_gpu_internal_users = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []
pallas_fuser_users = []
mosaic_extension_deps = []
jax_internal_export_back_compat_test_util_visibility = []