rocm_jax/jax/_src/cudnn/fusion.py

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import jax
from jax._src import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir


def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  return jax_core.jaxpr_as_fun(jaxpr)(*args)


def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  del args
  return jaxpr.out_avals


cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
cudnn_fusion_p.multiple_results = True
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)
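
# Note on the registrations above: `def_impl` gives the primitive an eager
# path, so outside of a trace `cudnn_fusion_p.bind(...)` just evaluates the
# captured jaxpr via `jax_core.jaxpr_as_fun`; the cuDNN fusion itself only
# materializes when the primitive is lowered under `jax.jit` (see below).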


def call_cudnn_fusion(f, *args, **kwargs):
  """Creates a new cudnn_fusion corresponding to calling
  the given function f with args and kwargs."""
  jaxpr, out_shapes = jax.make_jaxpr(
      functools.partial(f, **kwargs), return_shape=True
  )(*args)
  flat_args = jax.tree.leaves(args)
  out_tree = jax.tree.structure(out_shapes)
  out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
  return jax.tree.unflatten(out_tree, out_flat)
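
# A minimal usage sketch (the lambda and array names are illustrative, not
# part of this module):
#
#   out = call_cudnn_fusion(lambda a, b: a @ b, x, y)
#
# Keyword arguments are closed over with `functools.partial` before tracing,
# so they become compile-time constants in the jaxpr rather than operands of
# the primitive.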


def _cudnn_fusion_stablehlo_lowering(
    ctx,
    *args,
    name,
    jaxpr,
):
  """Makes a cudnn_fusion custom call which calls the implementation function.

  Currently this leaks a CallOp since we're using the `core_call_lowering`
  function, but this should get cleaned up by DCE easily.
  """
  # Lower the jaxpr to an ordinary StableHLO function call first; its callee
  # becomes the computation attached to the cuDNN custom call.
  impl = mlir.core_call_lowering(
      ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
  )
  call_op = impl[0].owner
  called_fn = call_op.attributes["callee"]
  # Wrap the called computation in a custom call targeting XLA's
  # __cudnn$fusion backend.
  cudnn_fusion = hlo.CustomCallOp(
      [r.type for r in call_op.results],
      call_op.operands,
      call_target_name="__cudnn$fusion",
      called_computations=ir.ArrayAttr.get([called_fn]),
  )
  return cudnn_fusion.results


mlir.register_lowering(
    cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
)
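
# The lowering is registered for the "cuda" platform only, so on other
# backends a jitted cudnn_fusion should fail to lower, while eager execution
# still works through `_cudnn_fusion_impl` above.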


def cudnn_fusion(f):
  """Makes a function become a cuDNN kernel. Relies on XLA's handling of
  custom fusions with the __cudnn$fusion backend. Currently limited to GEMM
  fusions. For example - batch matmul with mixed types and addition:

  @cudnn_fusion
  def fn(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
  """
  return functools.partial(call_cudnn_fusion, f)
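

# A minimal smoke test, as a sketch: it assumes a CUDA GPU with cuDNN is
# available, since the lowering above is registered only for "cuda". The
# shapes, dtypes, and the use of `astype` (in place of the docstring's
# dtype-constructor casts) are illustrative choices, not part of this module.
if __name__ == "__main__":
  import jax.numpy as jnp

  @cudnn_fusion
  def fn(x, y, z):
    # bf16 batch matmul whose result is cast to f32 and biased by z: a GEMM
    # fusion of the kind the __cudnn$fusion target supports.
    return jax.lax.batch_matmul(
        x.astype(jnp.bfloat16), y
    ).astype(jnp.float32) + z

  x = jnp.ones((2, 8, 16), dtype=jnp.bfloat16)
  y = jnp.ones((2, 16, 4), dtype=jnp.bfloat16)
  z = jnp.ones((2, 8, 4), dtype=jnp.float32)
  print(jax.jit(fn)(x, y, z).shape)  # expected: (2, 8, 4)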