mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import functools
|
|
import jax
|
|
from jax._src import core as jax_core
|
|
from jax.interpreters import mlir
|
|
from jax.interpreters.mlir import hlo
|
|
from jax.interpreters.mlir import ir
|
|
|
|
|
|
|
|
def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
|
|
del unused_kwargs
|
|
return jax_core.jaxpr_as_fun(jaxpr)(*args)
|
|
|
|
|
|
def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
|
|
del unused_kwargs
|
|
del args
|
|
return jaxpr.out_avals
|
|
|
|
|
|
# The cudnn_fusion primitive. It always produces multiple results (the
# flattened outputs of the traced function), evaluates eagerly by running
# the captured jaxpr, and reports output avals taken from that jaxpr.
cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
cudnn_fusion_p.multiple_results = True
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
|
|
|
|
|
|
def call_cudnn_fusion(f, *args, **kwargs):
    """Creates a new cudnn_fusion corresponding to calling
    the given function f with args and kwargs."""
    # Trace f (with kwargs pre-bound) to a jaxpr; also capture the output
    # pytree structure so results can be re-assembled after bind.
    closed_jaxpr, result_shapes = jax.make_jaxpr(
        functools.partial(f, **kwargs), return_shape=True
    )(*args)
    leaves = jax.tree.leaves(args)
    flat_results = cudnn_fusion_p.bind(
        *leaves, name=f.__name__, jaxpr=closed_jaxpr
    )
    return jax.tree.unflatten(jax.tree.structure(result_shapes), flat_results)
|
|
|
|
|
|
def _cudnn_fusion_stablehlo_lowering(
    ctx,
    *args,
    name,
    jaxpr,
):
    """Make cudnn_fusion which calls the implementation function.

    Currently this leaks a CallOp since we're using the `core_call_lowering`
    function, but this should get cleaned up by DCE easily.
    """
    # Lower the jaxpr as an ordinary call so its body becomes a named
    # StableHLO function we can point the custom call at.
    impl = mlir.core_call_lowering(
        ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
    )
    # impl[0] is a result value of the emitted CallOp; its owner is the op
    # itself, whose "callee" attribute names the lowered function.
    call_op = impl[0].owner
    called_fn = call_op.attributes["callee"]
    # Re-emit the same call as a __cudnn$fusion custom call, reusing the
    # CallOp's result types and operands. XLA recognizes this target and
    # hands the called computation to cuDNN as a fusion.
    cudnn_fusion = hlo.CustomCallOp(
        [r.type for r in call_op.results],
        call_op.operands,
        call_target_name="__cudnn$fusion",
        called_computations=ir.ArrayAttr.get([called_fn]),
    )
    return cudnn_fusion.results
|
|
|
|
|
|
# Register the lowering for CUDA only: the __cudnn$fusion custom-call
# target is understood exclusively by XLA's NVIDIA GPU backend.
mlir.register_lowering(
    cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
)
|
|
|
|
|
|
def cudnn_fusion(f):
    """Makes a function become a cuDNN kernel. Relies on XLA's handling of
    custom fusions with __cudnn$fusion backend. Currently limited to GEMM
    fusions. For example - batch matmul with mixed types and addition:

    @cudnn_fusion
    def fn(x, y, z):
        return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
    """
    # Use an explicit wrapper with functools.wraps instead of a bare
    # functools.partial: a partial object carries no __name__/__doc__,
    # so the decorated function would lose its metadata (hurting repr,
    # debugging, and introspection). Call behavior is unchanged.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return call_cudnn_fusion(f, *args, **kwargs)

    return wrapper
|