Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention

PiperOrigin-RevId: 655894058
jax authors 2024-07-25 03:32:45 -07:00
commit 2dadbd7eb6
2 changed files with 16 additions and 12 deletions


@@ -24,6 +24,7 @@ from jax._src import dispatch
 from jax._src.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
 from jax._src.lib import cuda_versions
+from jax._src import xla_bridge
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters.mlir import hlo
@@ -308,13 +309,13 @@ def check_cudnn_version():
     raise RuntimeError("cuDNN is not detected.")
   return cuda_versions.cudnn_get_version()
 
-def check_compute_capability(cc):
-  if cuda_versions is None:
-    raise RuntimeError("cuDNN is not detected.")
-  for i in range(jax.device_count()):
-    compute_cap = cuda_versions.cuda_compute_capability(i)
-    if compute_cap not in cc:
-      raise RuntimeError("Require compute capability in " + str(cc))
+def check_compute_capability(capability):
+  if not 'cuda' in xla_bridge.get_backend().platform_version:
+    return False
+  d, *_ = jax.local_devices(backend="gpu")
+  target = tuple(int(x) for x in capability.split("."))
+  current = tuple(int(x) for x in d.compute_capability.split("."))
+  return current >= target
 
 def _dot_product_attention_fwd(
     query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
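
Note on the new helper: `check_compute_capability` now parses a dotted capability string into an integer tuple and compares it against the first local GPU's `compute_capability`, so any architecture at or above the requested minimum passes. A minimal standalone sketch (not part of the diff; `parse_cc` is a name used only for this illustration) of why sm_86 and sm_89 now clear an "8.0" bar that the old `(80, 90)` allow-list rejected:

# Sketch only: mirrors the tuple comparison used by the new helper.
def parse_cc(capability: str) -> tuple[int, ...]:
  return tuple(int(x) for x in capability.split("."))

# Ampere sm_80/sm_86, Ada sm_89 and Hopper sm_90 all satisfy an "8.0" minimum;
# pre-Ampere parts such as Turing sm_75 do not.
assert parse_cc("8.6") >= parse_cc("8.0")
assert parse_cc("8.9") >= parse_cc("8.0")
assert parse_cc("9.0") >= parse_cc("8.0")
assert not parse_cc("7.5") >= parse_cc("8.0")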
@@ -986,8 +987,9 @@ def dot_product_attention(query: Array,
   """
   # check if cuDNN is installed
   cudnn_version = check_cudnn_version()
-  # only support Ampere and Hopper for now
-  check_compute_capability((80, 90))
+  # only support at least Ampere
+  if not check_compute_capability("8.0"):
+    raise RuntimeError("Require at least Ampere arch to run")
   layout = _normalize_layout(qkv_layout)
   if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None):
     raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask")
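
With this change, `dot_product_attention` raises a `RuntimeError` on pre-Ampere GPUs instead of enforcing an Ampere/Hopper allow-list, so sm_86 and sm_89 devices are now accepted. A minimal caller-side sketch (not from the diff) of gating on the same helper:

# Sketch: prefer the fused cuDNN attention path only on Ampere or newer GPUs.
from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability

if check_compute_capability("8.0"):
  # sm_80, sm_86, sm_89, sm_90, ... may use the cuDNN fused attention kernels.
  use_cudnn_attention = True
else:
  # Older GPUs (or non-CUDA backends) fall back to a plain XLA implementation.
  use_cudnn_attention = False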


@@ -30,7 +30,6 @@ from jax._src.cudnn.fused_attention_stablehlo import (
     dot_product_attention,
     check_is_flash_attention,
     check_cudnn_version,
-    check_compute_capability,
     get_large_negative_number,
     MaskType,
     AttentionLayout,
@@ -159,12 +158,13 @@ class DotProductAttentionTest(jtu.JaxTestCase):
       self.skipTest("Requires more than 4 devices.")
     try:
       cudnn_version = check_cudnn_version()
-      check_compute_capability((80, 90))
     except RuntimeError as e:
       self.skipTest(str(e))
       return
     if cudnn_version < 8904:
       self.skipTest("Requires >= cuDNN 8.9.4")
+    if not jtu.is_cuda_compute_capability_at_least("8.0"):
+      self.skipTest("Requires at least Ampere arch")
 
   @jtu.sample_product(
       batch_size=[4],
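
In the tests, the raising `check_compute_capability((80, 90))` call is replaced by `jtu.is_cuda_compute_capability_at_least`, which returns a boolean and lets each test skip below its own minimum architecture. A self-contained sketch of the same skip pattern (class and test names are placeholders, not from the diff):

from absl.testing import absltest

import jax
from jax._src import test_util as jtu

class CapabilityGateTest(jtu.JaxTestCase):  # placeholder test class

  def test_runs_on_ampere_or_newer(self):
    if not jtu.is_cuda_compute_capability_at_least("8.0"):
      self.skipTest("Requires at least Ampere arch")
    # Reached only on sm_80 / sm_86 / sm_89 / sm_90 and newer CUDA GPUs.
    x = jax.numpy.ones((2, 2))
    self.assertEqual(int(x.sum()), 4)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())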
@@ -340,12 +340,14 @@ class DotProductAttentionTest(jtu.JaxTestCase):
   def test_sdpa_broadcast_bias_and_dbias(self):
     try:
       cudnn_version = check_cudnn_version()
-      check_compute_capability((90,))
     except RuntimeError as e:
       self.skipTest(str(e))
       return
     if cudnn_version < 8906:
       self.skipTest("Requires >= cuDNN 8.9.6")
+    if not jtu.is_cuda_compute_capability_at_least("9.0"):
+      self.skipTest("Requires at least Hopper arch")
     k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5)
     query = jax.random.normal(
         k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)