mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change PiperOrigin-RevId: 733117882
This commit is contained in:
parent
0b6c355083
commit
d32e282ff9
23
jax/BUILD
23
jax/BUILD
@ -29,6 +29,7 @@ load(
|
||||
"jax_visibility",
|
||||
"mosaic_gpu_internal_users",
|
||||
"mosaic_internal_users",
|
||||
"pallas_fuser_users",
|
||||
"pallas_gpu_internal_users",
|
||||
"pallas_tpu_internal_users",
|
||||
"py_deps",
|
||||
@ -105,6 +106,12 @@ package_group(
|
||||
packages = pallas_tpu_internal_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "pallas_fuser_users",
|
||||
includes = [":internal"],
|
||||
packages = pallas_fuser_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "mosaic_gpu_users",
|
||||
includes = [":internal"],
|
||||
@ -628,6 +635,7 @@ pytype_strict_library(
|
||||
"experimental/pallas/ops/gpu/**/*.py",
|
||||
"experimental/pallas/ops/tpu/**/*.py",
|
||||
"experimental/pallas/tpu.py",
|
||||
"experimental/pallas/fuser.py",
|
||||
"experimental/pallas/triton.py",
|
||||
],
|
||||
),
|
||||
@ -664,6 +672,21 @@ pytype_strict_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_fuser",
|
||||
srcs = ["experimental/pallas/fuser.py"],
|
||||
visibility = [
|
||||
":pallas_fuser_users",
|
||||
],
|
||||
deps = [
|
||||
":pallas", # build_cleaner: keep
|
||||
"//jax/_src/pallas/fuser:block_spec",
|
||||
"//jax/_src/pallas/fuser:fusable",
|
||||
"//jax/_src/pallas/fuser:fusion",
|
||||
"//jax/_src/pallas/fuser:jaxpr_fusion",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_gpu_ops",
|
||||
srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"],
|
||||
|
@ -35,7 +35,7 @@ class Fusion(Generic[A, K]):
|
||||
in_type: tuple[tuple[Any, ...], dict[str, Any]]
|
||||
out_type: Any
|
||||
|
||||
def __call__(self, *args: A.args, **kwargs: A.kwargs):
|
||||
def __call__(self, *args: A.args, **kwargs: A.kwargs) -> K:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
|
23
jax/experimental/pallas/fuser.py
Normal file
23
jax/experimental/pallas/fuser.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Public API for fuser."""
|
||||
|
||||
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
|
||||
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.fusable import fusable as fusable
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
@ -46,6 +46,7 @@ mosaic_gpu_internal_users = []
|
||||
mosaic_internal_users = []
|
||||
pallas_gpu_internal_users = []
|
||||
pallas_tpu_internal_users = []
|
||||
pallas_fuser_users = []
|
||||
mosaic_extension_deps = []
|
||||
|
||||
jax_internal_export_back_compat_test_util_visibility = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user