Mirror of https://github.com/ROCm/jax.git
Rename tests with more descriptive names & unify SDPA API
Commit 49f1537f98 (parent 40eb11bc79)
@@ -155,29 +155,30 @@ def create_dot_product_attention_backend_config(batch,
   backend_config = json.dumps(backend_config)
   return backend_config
 
-def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
-  index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias
-  _custom_name_maps = [
-    # fMHA forward call targets.
-    "__cudnn$fhmaSoftmax",
-    "__cudnn$fhmaScaleBiasSoftmax",
-    "__cudnn$fhmaScaleMaskSoftmax",
-    "__cudnn$fhmaScaleBiasMaskSoftmax",
-    "__cudnn$fhmaSoftmaxDropout",
-    "__cudnn$fhmaScaleBiasSoftmaxDropout",
-    "__cudnn$fhmaScaleMaskSoftmaxDropout",
-    "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
-    # fMHA backward call targets.
-    "__cudnn$fhmaSoftmaxBackward",
-    "__cudnn$fhmaScaleBiasSoftmaxBackward",
-    "__cudnn$fhmaScaleMaskSoftmaxBackward",
-    "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
-    "__cudnn$fhmaSoftmaxDropoutBackward",
-    "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
-    "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
-    "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
-  ]
-  return _custom_name_maps[index]
+# mapping from (is_bwd, has_dropout, has_mask, has_bias) to custom call name
+_custom_name_maps = {
+  # fMHA forward call targets.
+  (False, False, False, False): "__cudnn$fhmaSoftmax",
+  (False, False, False, True): "__cudnn$fhmaScaleBiasSoftmax",
+  (False, False, True, False): "__cudnn$fhmaScaleMaskSoftmax",
+  (False, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmax",
+  (False, True, False, False): "__cudnn$fhmaSoftmaxDropout",
+  (False, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropout",
+  (False, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropout",
+  (False, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
+  # fMHA backward call targets.
+  (True, False, False, False): "__cudnn$fhmaSoftmaxBackward",
+  (True, False, False, True): "__cudnn$fhmaScaleBiasSoftmaxBackward",
+  (True, False, True, False): "__cudnn$fhmaScaleMaskSoftmaxBackward",
+  (True, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
+  (True, True, False, False): "__cudnn$fhmaSoftmaxDropoutBackward",
+  (True, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
+  (True, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
+  (True, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
+}
+
+def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
+  return _custom_name_maps[(is_bwd, has_dropout, has_mask, has_bias)]
 
 def check_qkv_layout(query, key, value):
   assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
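The hunk above replaces a bit-packed list index with a dictionary keyed directly by the flag tuple. A minimal sketch, not part of the commit, showing that the two lookups select the same cuDNN custom-call target (old_index and new_key are illustrative names, not names from the source):

# Old scheme: bit 3 = is_bwd, bit 2 = has_dropout, bit 1 = has_mask, bit 0 = has_bias.
def old_index(has_bias, has_mask, has_dropout, is_bwd):
  return is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias

# New scheme: dictionary key ordered as (is_bwd, has_dropout, has_mask, has_bias).
def new_key(has_bias, has_mask, has_dropout, is_bwd):
  return (is_bwd, has_dropout, has_mask, has_bias)

# Forward pass with bias and dropout but no mask:
assert old_index(True, False, True, False) == 5                         # list slot of "__cudnn$fhmaScaleBiasSoftmaxDropout"
assert new_key(True, False, True, False) == (False, True, False, True)  # dict key for the same target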
@@ -673,9 +674,9 @@ _dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_atte
 def dot_product_attention(query: Array,
                           key: Array,
                           value: Array,
+                          scale: float = 1.0,
                           bias: Optional[Array] = None,
                           mask: Optional[Array] = None,
-                          scale: float = 1.0,
                           is_causal_mask: bool = False,
                           seed: int = 42,
                           dropout_rate: float = 0.):
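With scale moved to sit directly after value, a call site of the unified API reads as in the hedged sketch below. The shapes, the bfloat16 dtype, and the need for a cuDNN-capable GPU build of JAX are assumptions for illustration, not details taken from the diff:

import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

# Hypothetical shapes: (batch, seq_len, num_heads, head_dim).
B, T, N, H = 4, 1024, 8, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
key   = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# scale is now the first optional argument after value; bias and mask stay optional.
out = dot_product_attention(query, key, value,
                            scale=1.0 / H ** 0.5,
                            is_causal_mask=True,
                            dropout_rate=0.0)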
@@ -709,7 +710,8 @@ def dot_product_attention(query: Array,
   is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
   # check if cuDNN is installed and if cuDNN version is sufficient
   check_cudnn_version(is_flash_attention, is_cross_attention)
-
+  if mask is not None and is_causal_mask:
+    raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
   variadic_args = (bias is not None, mask is not None)
   if bias is None:
     bias = jnp.zeros(0, dtype=query.dtype)
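The new guard rejects an explicit mask combined with a generated causal mask. A small sketch of the expected behavior, assuming a (batch, num_heads, q_len, kv_len) mask layout and a GPU build where the earlier cuDNN checks pass; none of these assumptions come from the diff:

import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

B, T, N, H = 2, 128, 2, 64                           # hypothetical small shapes
q = k = v = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
mask = jnp.zeros((B, N, T, T), dtype=jnp.bfloat16)   # assumed mask layout

try:
  dot_product_attention(q, k, v, scale=1.0, mask=mask, is_causal_mask=True)
except ValueError as err:
  print(err)  # "can not apply a mask and generate a causal_mask at the same time."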
@@ -16,7 +16,7 @@ from functools import partial
 from absl.testing import absltest
 from typing import Any, Optional
 import os
-os.environ['XLA_FLAGS'] = '--xla_dump_disable_metadata --xla_gpu_enable_triton_gemm=false --xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_dump_hlo_module_re=.*pjit__unnamed_function.* --xla_dump_hlo_pass_re=.* --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
+os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
 
 import numpy as np
 import jax
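This and the remaining hunks are from the test module. The test now keeps only the two flags that gate the cuDNN fused-attention path and drops the HLO-dump flags. As a reminder of how the flags take effect, XLA_FLAGS has to be in the environment before the GPU backend initializes, which is why the test sets it ahead of the jax import; a minimal sketch of that pattern:

import os
# Set before the first jax import so XLA reads the flags when the backend comes up.
os.environ['XLA_FLAGS'] = ('--xla_gpu_enable_cudnn_fmha=true '
                           '--xla_gpu_fused_attention_use_cudnn_rng=true')

import jax  # noqa: E402  (deliberately imported after setting the env var)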
@@ -30,72 +30,49 @@ from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
 config.parse_flags_with_absl()
 Array = jnp.ndarray
 
-def f(query: Array,
-      key: Array,
-      value: Array,
-      bias: Optional[Array] = None,
-      mask: Optional[Array] = None,
-      causal_mask: bool = False,
-      scale: float = 0.5,
-      dropout_rate: float = 0.1) -> Array:
-
-  output = dot_product_attention(
-      query,
-      key,
-      value,
-      scale=scale,
-      bias=bias,
-      mask=mask,
-      is_causal_mask=causal_mask,
-      dropout_rate=dropout_rate)
-  return output
-
-def f_train(query: Array,
+def sdpa_train(query: Array,
               key: Array,
               value: Array,
               grad: Array,
               bias: Optional[Array] = None,
               mask: Optional[Array] = None,
-              causal_mask: bool = False,
               scale: float = 0.5,
+              is_causal_mask: bool = False,
               dropout_rate: float = 0.1) -> Array:
 
-  out, f_vjp = jax.vjp(
-    partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
+  out, sdpa_vjp = jax.vjp(
+    partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
     query, key, value, bias, None)
-  query_grad, key_grad, value_grad, _, _ = f_vjp(grad)
+  query_grad, key_grad, value_grad, _, _ = sdpa_vjp(grad)
   return out, (query_grad, key_grad, value_grad)
 
-def g(query: Array,
+def sdpa_ref(query: Array,
             key: Array,
             value: Array,
             bias: Optional[Array] = None,
             mask: Optional[Array] = None,
-            causal_mask: bool = False,
             scale: float = 0.5,
+            is_causal_mask: bool = False,
             dropout_rate: float = 0.1) -> Array:
 
-  def get_large_negative_number(dtype):
-    if jnp.issubdtype(dtype, jnp.inexact):
-      dtype_max = jnp.finfo(dtype).max
-    elif jnp.issubdtype(dtype, jnp.integer):
-      dtype_max = jnp.iinfo(dtype).max
-    else:
-      raise ValueError('Unsupported dtype for inputs.')
-    return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
-
   def get_causal_mask(input_t):
-    large_negative_number = get_large_negative_number(input_t.dtype)
+    dtype = input_t.dtype
+    if jnp.issubdtype(dtype, jnp.inexact):
+      dtype_max = jnp.finfo(dtype).max
+    elif jnp.issubdtype(dtype, jnp.integer):
+      dtype_max = jnp.iinfo(dtype).max
+    else:
+      raise ValueError('Unsupported dtype for inputs.')
+    large_negative_number = jnp.asarray(-0.7 * dtype_max, dtype=dtype)
     t = input_t.shape[2]
-    col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
-    row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
+    col_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 1)
+    row_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 0)
     mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
     return mask[jnp.newaxis, jnp.newaxis, :, :]
 
   if scale != 1.0:
     query = query * scale
   attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
-  if causal_mask:
+  if is_causal_mask:
     bias = get_causal_mask(attn_weights)
   if bias is not None:
     attn_weights = attn_weights + bias.astype(attn_weights.dtype)
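One behavioral detail in sdpa_ref above: the causal mask now builds its index grids with jax.lax.broadcasted_iota instead of jnp.tile over jnp.arange. A standalone sketch with a made-up sequence length showing the two constructions agree:

import jax
import jax.numpy as jnp
import numpy as np

t = 8  # hypothetical sequence length

# Old construction (removed): materialize the index grids via arange + tile.
col_old = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
row_old = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])

# New construction: one broadcasted_iota primitive per grid.
col_new = jax.lax.broadcasted_iota(np.int32, (t, t), 1)
row_new = jax.lax.broadcasted_iota(np.int32, (t, t), 0)

assert (col_old == col_new).all() and (row_old == row_new).all()
# Either pair marks the same future positions (strictly above the diagonal),
# which then receive the large negative additive bias.
future = row_new < col_new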
@@ -108,19 +85,19 @@ def g(query: Array,
 
   return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
 
-def g_train(query: Array,
+def sdpa_train_ref(query: Array,
                    key: Array,
                    value: Array,
                    grad: Array,
                    bias: Optional[Array] = None,
                    mask: Optional[Array] = None,
-                   causal_mask: bool = False,
                    scale: float = 0.5,
+                   is_causal_mask: bool = False,
                    dropout_rate: float = 0.1) -> Array:
-  out_ref, g_vjp = jax.vjp(
-    partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
+  out_ref, sdpa_vjp_ref = jax.vjp(
+    partial(sdpa_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
     query, key, value, bias, None)
-  query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad)
+  query_grad_ref, key_grad_ref, value_grad_ref, _, _ = sdpa_vjp_ref(grad)
   return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)
 
 class DotProductAttentionTest(jtu.JaxTestCase):
@@ -161,8 +138,6 @@ class DotProductAttentionTest(jtu.JaxTestCase):
     else:
       bias = None
 
-    jitted_f_train = jax.jit(partial(f_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
-    jitted_g_train = jax.jit(partial(g_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
     devices = np.array(jax.local_devices()[:4])
     devices = devices.reshape((2, 2))
     with Mesh(devices, ('dp', 'tp')) as mesh:
@@ -182,18 +157,20 @@ class DotProductAttentionTest(jtu.JaxTestCase):
       grad = jax.device_put(grad, qkv_sharding)
       in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, replicated)
       out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))
-      pjitted_f_train = jax.jit(jitted_f_train,
+      jitted_sdpa_train = jax.jit(
+        partial(sdpa_train, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
         in_shardings=in_shardings,
         out_shardings=out_shardings
       )
 
-      pjitted_g_train = jax.jit(jitted_g_train,
+      jitted_sdpa_train_ref = jax.jit(
+        partial(sdpa_train_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
         in_shardings=in_shardings,
         out_shardings=out_shardings
       )
 
-      out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
-      out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
+      out, (query_grad, key_grad, value_grad) = jitted_sdpa_train(query, key, value, grad, bias, None)
+      out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = jitted_sdpa_train_ref(query, key, value, grad, bias, None)
       self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
       if seq_len > 512:
         # query_grad in flash attention is not deterministic
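The hunk above reuses qkv_sharding, bias_sharding, and replicated, whose definitions sit outside the diff. The sketch below is a hedged reconstruction of how such shardings are commonly built over the (2, 2) ('dp', 'tp') mesh from the earlier hunk; the PartitionSpec choices are assumptions, and at least four local devices are assumed:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
  # Assumed specs: shard Q/K/V over batch ('dp') and heads ('tp'),
  # shard the bias over batch and heads, replicate everything else.
  qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))
  bias_sharding = NamedSharding(mesh, PartitionSpec('dp', 'tp', None, None))
  replicated = NamedSharding(mesh, PartitionSpec())

  in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding,
                  bias_sharding, replicated)
  out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))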