mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions.
parent ea5fd29b90
commit 85d792a92d
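
For orientation, here is a minimal usage sketch of the new decorator, assembled from the decorator's docstring and the new test in this commit; the `fused` name, shapes, and dtypes are illustrative, not part of the change.

import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion  # export added by this commit

@cudnn_fusion
def fused(x, y, z):
  # Mixed-precision batched matmul plus an add, handed to XLA as one cuDNN fusion.
  return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

k = jax.random.key(0)
s = (2, 16, 16)
x = jnp.int8(jax.random.normal(k, shape=s))
y = jnp.bfloat16(jax.random.normal(k, shape=s))
z = jnp.float32(jax.random.normal(k, shape=s))

# On a CUDA backend (sm80+), the jitted computation lowers to a
# __cudnn$fusion custom call.
out = jax.jit(fused)(x, y, z)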
jax/_src/cudnn/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .fusion import cudnn_fusion
jax/_src/cudnn/fusion.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import jax
from jax import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir


def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  return jax_core.jaxpr_as_fun(jaxpr)(*args)


def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  del args
  return jaxpr.out_avals


cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
cudnn_fusion_p.multiple_results = True
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)


def call_cudnn_fusion(f, *args, **kwargs):
  """Creates a new cudnn_fusion corresponding to calling
  the given function f with args and kwargs."""
  jaxpr, out_shapes = jax.make_jaxpr(
      functools.partial(f, **kwargs), return_shape=True
  )(*args)
  flat_args = jax.tree.leaves(args)
  out_tree = jax.tree.structure(out_shapes)
  out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
  return jax.tree.unflatten(out_tree, out_flat)


def _cudnn_fusion_stablehlo_lowering(
    ctx,
    *args,
    name,
    jaxpr,
):
  """Make cudnn_fusion which calls the implementation function.
  Currently this leaks a CallOp since we're using the `core_call_lowering`
  function, but this should get cleaned up by DCE easily.
  """
  impl = mlir.core_call_lowering(
      ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
  )
  call_op = impl[0].owner
  called_fn = call_op.attributes["callee"]
  cudnn_fusion = hlo.CustomCallOp(
      [r.type for r in call_op.results],
      call_op.operands,
      call_target_name="__cudnn$fusion",
      called_computations=ir.ArrayAttr.get([called_fn]),
  )
  return cudnn_fusion.results


mlir.register_lowering(
    cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
)


def cudnn_fusion(f):
  """Makes a function become a cuDNN kernel. Relies on XLA's handling of
  custom fusions with __cudnn$fusion backend. Currently limited to GEMM
  fusions. For example - batch matmul with mixed types and addition:

  @cudnn_fusion
  def fn(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
  """
  return functools.partial(call_cudnn_fusion, f)
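As a quick check that the lowering registered above takes effect, the lowered module can be inspected for the custom call. This continues the usage sketch from the commit header and mirrors the assertions in the new test below.

# Continuing the sketch above (`fused`, `x`, `y`, `z` as defined there).
lowered = jax.jit(fused).lower(x, y, z)
# The decorated body shows up as a __cudnn$fusion custom call in both dialects.
assert "__cudnn$fusion" in lowered.as_text("stablehlo")
assert 'custom_call_target="__cudnn$fusion"' in lowered.as_text("hlo")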
tests/BUILD (14 lines added)
@@ -1523,6 +1523,20 @@ py_test(
     ],
 )
 
+jax_test(
+    name = "cudnn_fusion_test",
+    srcs = ["cudnn_fusion_test.py"],
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    enable_configs = [
+        "gpu_a100",
+        "gpu_h100",
+    ],
+    tags = ["multiaccelerator"],
+)
+
 exports_files(
     [
         "api_test.py",
tests/cudnn_fusion_test.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest, parameterized
from unittest import SkipTest
from jax._src import test_util as jtu
import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion


jax.config.parse_flags_with_absl()


class CudnnFusionTest(jtu.JaxTestCase):
  def setUp(self):
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_at_least("8.0")):
      self.skipTest("Only works on >= sm80 GPUs")
    super().setUp()

  @parameterized.parameters(["", "pmap"])
  @jtu.run_on_devices("cuda")
  def test_cudnn_fusion(self, mode):
    batch_size = 2
    if mode == "pmap" and jax.device_count() < batch_size:
      raise SkipTest("pmap test requires 2 GPUs")

    @cudnn_fusion
    def comp1(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

    k = jax.random.key(0)
    s = batch_size, 16, 16
    x = jnp.int8(jax.random.normal(k, shape=s))
    y = jnp.bfloat16(jax.random.normal(k, shape=s))
    z = jnp.float32(jax.random.normal(k, shape=s))

    fn = jax.pmap(comp1) if mode == "pmap" else comp1
    jitted = jax.jit(comp1)
    lowered = jitted.lower(x, y, z)
    stablehlo = lowered.as_text("stablehlo")
    self.assertIn("func.func private @comp1", stablehlo)
    self.assertIn("__cudnn$fusion", stablehlo)

    hlo = lowered.as_text("hlo")
    self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
    self.assertIn("called_computations=", hlo)

    hlo_after_opt = lowered.compile().as_text()
    self.assertIn("kind=kCustom", hlo_after_opt)
    self.assertIn("plan_id", hlo_after_opt)

    self.assertAllClose(jitted(x, y, z), fn(x, y, z))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())