Rename tests with more descriptive names & unify SDPA API

Cjkkkk 2024-01-18 22:54:23 -08:00
parent 40eb11bc79
commit 49f1537f98
2 changed files with 56 additions and 77 deletions

jax/_src/cudnn/fused_attention_stablehlo.py

@@ -155,29 +155,30 @@ def create_dot_product_attention_backend_config(batch,
backend_config = json.dumps(backend_config)
return backend_config
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias
_custom_name_maps = [
# fMHA forward call targets.
"__cudnn$fhmaSoftmax",
"__cudnn$fhmaScaleBiasSoftmax",
"__cudnn$fhmaScaleMaskSoftmax",
"__cudnn$fhmaScaleBiasMaskSoftmax",
"__cudnn$fhmaSoftmaxDropout",
"__cudnn$fhmaScaleBiasSoftmaxDropout",
"__cudnn$fhmaScaleMaskSoftmaxDropout",
"__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
# mapping from (is_bwd, has_dropout, has_mask, has_bias) to custom call name
_custom_name_maps = {
# fMHA forward call targets.
(False, False, False, False): "__cudnn$fmhaSoftmax",
(False, False, False, True): "__cudnn$fmhaScaleBiasSoftmax",
(False, False, True, False): "__cudnn$fmhaScaleMaskSoftmax",
(False, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmax",
(False, True, False, False): "__cudnn$fmhaSoftmaxDropout",
(False, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropout",
(False, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropout",
(False, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropout",
# fMHA backward call targets.
"__cudnn$fhmaSoftmaxBackward",
"__cudnn$fhmaScaleBiasSoftmaxBackward",
"__cudnn$fhmaScaleMaskSoftmaxBackward",
"__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
"__cudnn$fhmaSoftmaxDropoutBackward",
"__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
"__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
"__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
]
return _custom_name_maps[index]
(True, False, False, False): "__cudnn$fmhaSoftmaxBackward",
(True, False, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward",
(True, False, True, False): "__cudnn$fmhaScaleMaskSoftmaxBackward",
(True, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxBackward",
(True, True, False, False): "__cudnn$fmhaSoftmaxDropoutBackward",
(True, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward",
(True, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward",
(True, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"
}
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
return _custom_name_maps[(is_bwd, has_dropout, has_mask, has_bias)]
def check_qkv_layout(query, key, value):
assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
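In the hunk above, the bit-packed index (is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias) into a flat list is replaced by a dict keyed directly on the flag tuple, so each flag-to-target mapping is explicit. A minimal sketch of the new lookup, assuming the get_custom_call_name definition above; the flag combinations chosen here are illustrative:

# Sketch, not part of the diff: two illustrative lookups against the tuple-keyed map.
from jax._src.cudnn.fused_attention_stablehlo import get_custom_call_name

# forward, bias only
assert get_custom_call_name(has_bias=True, has_mask=False, has_dropout=False, is_bwd=False) \
    == "__cudnn$fmhaScaleBiasSoftmax"
# backward, bias + mask + dropout
assert get_custom_call_name(has_bias=True, has_mask=True, has_dropout=True, is_bwd=True) \
    == "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"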
@@ -673,9 +674,9 @@ _dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_atte
def dot_product_attention(query: Array,
key: Array,
value: Array,
scale: float = 1.0,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
scale: float = 1.0,
is_causal_mask: bool = False,
seed: int = 42,
dropout_rate: float = 0.):
@@ -709,7 +710,8 @@ def dot_product_attention(query: Array,
is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
# check if cuDNN is installed and if cuDNN version is sufficient
check_cudnn_version(is_flash_attention, is_cross_attention)
if mask is not None and is_causal_mask:
raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
variadic_args = (bias is not None, mask is not None)
if bias is None:
bias = jnp.zeros(0, dtype=query.dtype)
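In the signature hunk above, scale moves after bias and mask, is_causal_mask joins the unified signature, and passing an explicit mask together with is_causal_mask=True is now rejected. A hedged usage sketch of the unified entry point, assuming a cuDNN-capable GPU and the 4D (batch, seq, heads, head_dim) layout used by the tests below; shapes and dtype are illustrative:

# Sketch, not part of the diff: calling the unified SDPA API.
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

B, S, N, H = 4, 1024, 8, 64                        # illustrative batch, seq, heads, head_dim
q = jnp.zeros((B, S, N, H), dtype=jnp.bfloat16)
k = jnp.zeros((B, S, N, H), dtype=jnp.bfloat16)
v = jnp.zeros((B, S, N, H), dtype=jnp.bfloat16)
out = dot_product_attention(q, k, v,
                            scale=1.0 / H ** 0.5,  # softmax scaling
                            is_causal_mask=True,   # mask is generated internally; do not also pass mask=
                            dropout_rate=0.0)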

tests/fused_attention_stablehlo_test.py

@@ -16,7 +16,7 @@ from functools import partial
from absl.testing import absltest
from typing import Any, Optional
import os
os.environ['XLA_FLAGS'] = '--xla_dump_disable_metadata --xla_gpu_enable_triton_gemm=false --xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_dump_hlo_module_re=.*pjit__unnamed_function.* --xla_dump_hlo_pass_re=.* --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
import numpy as np
import jax
@@ -30,72 +30,49 @@ from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
config.parse_flags_with_absl()
Array = jnp.ndarray
def f(query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
causal_mask: bool = False,
scale: float = 0.5,
dropout_rate: float = 0.1) -> Array:
output = dot_product_attention(
query,
key,
value,
scale=scale,
bias=bias,
mask=mask,
is_causal_mask=causal_mask,
dropout_rate=dropout_rate)
return output
def f_train(query: Array,
def sdpa_train(query: Array,
key: Array,
value: Array,
grad: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
causal_mask: bool = False,
scale: float = 0.5,
is_causal_mask: bool = False,
dropout_rate: float = 0.1) -> Array:
out, f_vjp = jax.vjp(
partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
out, sdpa_vjp = jax.vjp(
partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
query, key, value, bias, None)
query_grad, key_grad, value_grad, _, _ = f_vjp(grad)
query_grad, key_grad, value_grad, _, _ = sdpa_vjp(grad)
return out, (query_grad, key_grad, value_grad)
def g(query: Array,
def sdpa_ref(query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
causal_mask: bool = False,
scale: float = 0.5,
is_causal_mask: bool = False,
dropout_rate: float = 0.1) -> Array:
def get_large_negative_number(dtype):
if jnp.issubdtype(dtype, jnp.inexact):
dtype_max = jnp.finfo(dtype).max
elif jnp.issubdtype(dtype, jnp.integer):
dtype_max = jnp.iinfo(dtype).max
else:
raise ValueError('Unsupported dtype for inputs.')
return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
def get_causal_mask(input_t):
large_negative_number = get_large_negative_number(input_t.dtype)
dtype = input_t.dtype
if jnp.issubdtype(dtype, jnp.inexact):
dtype_max = jnp.finfo(dtype).max
elif jnp.issubdtype(dtype, jnp.integer):
dtype_max = jnp.iinfo(dtype).max
else:
raise ValueError('Unsupported dtype for inputs.')
large_negative_number = jnp.asarray(-0.7 * dtype_max, dtype=dtype)
t = input_t.shape[2]
col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
col_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 1)
row_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 0)
mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
return mask[jnp.newaxis, jnp.newaxis, :, :]
if scale != 1.0:
query = query * scale
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
if causal_mask:
if is_causal_mask:
bias = get_causal_mask(attn_weights)
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
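In the reference implementation above, the causal bias is now built with jax.lax.broadcasted_iota instead of tiling jnp.arange; both produce the same row/column index grids. A standalone sketch of that construction (the toy length t=4 is illustrative):

# Sketch, not part of the diff: the broadcasted_iota causal bias on its own.
import jax
import jax.numpy as jnp
import numpy as np

t = 4                                                    # illustrative sequence length
col_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 1)  # column index at every (row, col)
row_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 0)  # row index at every (row, col)
large_negative = jnp.asarray(-0.7 * jnp.finfo(jnp.float32).max, dtype=jnp.float32)
mask = (row_idx < col_idx).astype(jnp.float32) * large_negative
# zeros on and below the diagonal, ~-2.4e38 above it, same as the tile/arange version it replaces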
@@ -108,19 +85,19 @@ def g(query: Array,
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
def g_train(query: Array,
def sdpa_train_ref(query: Array,
key: Array,
value: Array,
grad: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
causal_mask: bool = False,
scale: float = 0.5,
is_causal_mask: bool = False,
dropout_rate: float = 0.1) -> Array:
out_ref, g_vjp = jax.vjp(
partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
out_ref, sdpa_vjp_ref = jax.vjp(
partial(sdpa_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
query, key, value, bias, None)
query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad)
query_grad_ref, key_grad_ref, value_grad_ref, _, _ = sdpa_vjp_ref(grad)
return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)
class DotProductAttentionTest(jtu.JaxTestCase):
@@ -161,8 +138,6 @@ class DotProductAttentionTest(jtu.JaxTestCase):
else:
bias = None
jitted_f_train = jax.jit(partial(f_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
jitted_g_train = jax.jit(partial(g_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
@@ -182,18 +157,20 @@ class DotProductAttentionTest(jtu.JaxTestCase):
grad = jax.device_put(grad, qkv_sharding)
in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, replicated)
out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))
pjitted_f_train = jax.jit(jitted_f_train,
jitted_sdpa_train = jax.jit(
partial(sdpa_train, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
in_shardings=in_shardings,
out_shardings=out_shardings
)
pjitted_g_train = jax.jit(jitted_g_train,
jitted_sdpa_train_ref = jax.jit(
partial(sdpa_train_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
in_shardings=in_shardings,
out_shardings=out_shardings
)
out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
out, (query_grad, key_grad, value_grad) = jitted_sdpa_train(query, key, value, grad, bias, None)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = jitted_sdpa_train_ref(query, key, value, grad, bias, None)
self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
if seq_len > 512:
# query_grad in flash attention is not deterministic
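For context on the sharded test above: the in_shardings/out_shardings passed to jax.jit come from a 2x2 Mesh with 'dp' and 'tp' axes, as in this condensed sketch. The axis names and device shape come from the test itself, while the PartitionSpec layout shown here is an assumption for illustration:

# Sketch, not part of the diff: the Mesh/NamedSharding setup behind the jitted test functions.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
  # Assumed layout: shard batch over 'dp' and heads over 'tp'; the test's actual
  # qkv_sharding/bias_sharding specs may differ.
  qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))
  replicated = NamedSharding(mesh, PartitionSpec())
  # These sharding objects are what the hunk above passes to jax.jit via
  # in_shardings=(...) and out_shardings=(...).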