Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention
PiperOrigin-RevId: 655894058
Commit 2dadbd7eb6
@@ -24,6 +24,7 @@ from jax._src import dispatch
 from jax._src.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
 from jax._src.lib import cuda_versions
+from jax._src import xla_bridge
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters.mlir import hlo
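
The newly imported xla_bridge is what the rewritten capability check below uses to confirm the active backend is actually CUDA before it inspects any GPU device. A minimal standalone sketch of the two pieces of backend state the new check reads (the print-based framing is mine, not part of the change):

# Sketch only: reads the same backend state as the new check_compute_capability.
import jax
from jax._src import xla_bridge

if "cuda" in xla_bridge.get_backend().platform_version:
  d, *_ = jax.local_devices(backend="gpu")
  print("compute capability:", d.compute_capability)  # e.g. "8.6" on an RTX 30-series GPU
else:
  print("Not a CUDA backend; the new check would return False here.")
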
@@ -308,13 +309,13 @@ def check_cudnn_version():
     raise RuntimeError("cuDNN is not detected.")
   return cuda_versions.cudnn_get_version()

-def check_compute_capability(cc):
-  if cuda_versions is None:
-    raise RuntimeError("cuDNN is not detected.")
-  for i in range(jax.device_count()):
-    compute_cap = cuda_versions.cuda_compute_capability(i)
-    if compute_cap not in cc:
-      raise RuntimeError("Require compute capability in " + str(cc))
+def check_compute_capability(capability):
+  if not 'cuda' in xla_bridge.get_backend().platform_version:
+    return False
+  d, *_ = jax.local_devices(backend="gpu")
+  target = tuple(int(x) for x in capability.split("."))
+  current = tuple(int(x) for x in d.compute_capability.split("."))
+  return current >= target

 def _dot_product_attention_fwd(
     query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
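
The replacement compares capabilities as tuples of integers against a minimum version string instead of matching the device against a hard-coded (80, 90) set, which is what lets sm86 and sm89 parts (consumer Ampere and Ada GPUs such as the RTX 30/40 series) take the flash attention path. A standalone sketch of that comparison, with the parsing pulled out so it runs without a GPU (helper names here are mine):

# Sketch of the tuple comparison behind the new check_compute_capability.
# Parsing to tuples of ints avoids lexicographic string pitfalls:
# as strings "8.9" > "8.10", but as tuples (8, 9) < (8, 10).
def _parse_cc(capability: str) -> tuple[int, ...]:   # hypothetical helper
  return tuple(int(x) for x in capability.split("."))

def meets_capability(current: str, target: str) -> bool:   # hypothetical helper
  return _parse_cc(current) >= _parse_cc(target)

assert meets_capability("8.0", "8.0")      # A100: exactly Ampere
assert meets_capability("8.6", "8.0")      # sm86 now passes
assert meets_capability("8.9", "8.0")      # sm89 now passes
assert meets_capability("9.0", "8.0")      # Hopper still passes
assert not meets_capability("7.5", "8.0")  # Turing is rejected
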
@@ -986,8 +987,9 @@ def dot_product_attention(query: Array,
   """
   # check if cuDNN is installed
   cudnn_version = check_cudnn_version()
-  # only support Ampere and Hopper for now
-  check_compute_capability((80, 90))
+  # only support at least Ampere
+  if not check_compute_capability("8.0"):
+    raise RuntimeError("Require at least Ampere arch to run")
   layout = _normalize_layout(qkv_layout)
   if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None):
     raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask")
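
With this change dot_product_attention raises a RuntimeError on anything older than Ampere rather than requiring an exact capability match. A hedged usage sketch of how a caller might guard for that, assuming a CUDA build of jax/jaxlib and the default BTNH layout; the fallback branch is illustrative and not part of the commit:

# Sketch only: call the cuDNN flash attention path, fall back on older GPUs.
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)  # (batch, seq, heads, head_dim)
k = jax.random.normal(k2, (4, 1024, 4, 64), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (4, 1024, 4, 64), dtype=jnp.bfloat16)
scale = 1.0 / (q.shape[-1] ** 0.5)

try:
  out = dot_product_attention(q, k, v, scale=scale)  # cuDNN fused attention
except RuntimeError:  # e.g. "Require at least Ampere arch to run"
  # Plain reference attention as a fallback for pre-Ampere GPUs.
  logits = jnp.einsum("btnh,bsnh->bnts", q, k) * scale
  out = jnp.einsum("bnts,bsnh->btnh", jax.nn.softmax(logits, axis=-1), v)
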
@@ -30,7 +30,6 @@ from jax._src.cudnn.fused_attention_stablehlo import (
     dot_product_attention,
     check_is_flash_attention,
     check_cudnn_version,
-    check_compute_capability,
     get_large_negative_number,
     MaskType,
     AttentionLayout,
@@ -159,12 +158,13 @@ class DotProductAttentionTest(jtu.JaxTestCase):
       self.skipTest("Requires more than 4 devices.")
     try:
       cudnn_version = check_cudnn_version()
-      check_compute_capability((80, 90))
     except RuntimeError as e:
       self.skipTest(str(e))
       return
     if cudnn_version < 8904:
       self.skipTest("Requires >= cuDNN 8.9.4")
+    if not jtu.is_cuda_compute_capability_at_least("8.0"):
+      self.skipTest("Requires at least Ampere arch")

   @jtu.sample_product(
       batch_size=[4],
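
On the test side, the capability gate moves out of the try/except (the new check_compute_capability no longer raises) and onto jtu.is_cuda_compute_capability_at_least. A minimal sketch of that skip pattern as a self-contained, hypothetical test file (the class name and test body are mine):

# Sketch only: the skip-guard pattern the tests adopt in this commit.
from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version

class FlashAttentionGuardTest(jtu.JaxTestCase):  # hypothetical test class

  def test_requires_ampere(self):
    try:
      cudnn_version = check_cudnn_version()  # raises RuntimeError if cuDNN is absent
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if cudnn_version < 8904:
      self.skipTest("Requires >= cuDNN 8.9.4")
    if not jtu.is_cuda_compute_capability_at_least("8.0"):
      self.skipTest("Requires at least Ampere arch")
    # ... the actual flash attention assertions would go here ...

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
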
@@ -340,12 +340,14 @@ class DotProductAttentionTest(jtu.JaxTestCase):
   def test_sdpa_broadcast_bias_and_dbias(self):
     try:
       cudnn_version = check_cudnn_version()
-      check_compute_capability((90,))
     except RuntimeError as e:
       self.skipTest(str(e))
       return
     if cudnn_version < 8906:
       self.skipTest("Requires >= cuDNN 8.9.6")
+    if not jtu.is_cuda_compute_capability_at_least("9.0"):
+      self.skipTest("Requires at least Hopper arch")

     k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5)
     query = jax.random.normal(
         k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
Loading…