add unit test and move to _src/cudnn dir

Cjkkkk 2024-01-02 13:48:53 -08:00
parent d1141b4058
commit 9b8a100039
3 changed files with 263 additions and 142 deletions

@@ -32,6 +32,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax._src.interpreters import batching
from jax._src import dispatch
from jax._src.lib import cuda_versions
Array = jnp.ndarray
DType = jnp.dtype
@@ -120,6 +121,50 @@ def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
  ]
  return _custom_name_maps[index]

def check_qkv_layout(query, key, value):
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
    "query, key and value should have rank 4."

  # Only support fp16 and bf16 here
  query_dtype = query.dtype
  key_dtype = key.dtype
  value_dtype = value.dtype
  assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
    "query, key and value should have the same dtype and should be float16 or bfloat16."

  q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
  k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
  v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
  assert (q_batch == k_batch == v_batch) \
    and (k_seq_len == v_seq_len) \
    and (q_num_heads == k_num_heads == v_num_heads) \
    and (q_head_dim == k_head_dim == v_head_dim), \
    "query should have layout [batch, q_seq, num_heads, head_dim], " \
    "key and value should have layout [batch, kv_seq, num_heads, head_dim]."

def check_is_flash_attention(query, key):
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  # check whether the attention pattern is supported by flash attention or by
  # the regular fused attention kernel
  if q_seq_len > 512 and q_seq_len == kv_seq_len and head_dim in [64, 128]:
    # flash attention path
    is_flash_attention = True
  elif q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
    # regular fused attention path
    is_flash_attention = False
  else:
    raise NotImplementedError("Unsupported sequence length and head dim.")
  return is_flash_attention

def check_cuDNN_version(is_flash_attention):
  # check that cuDNN is installed and that the cuDNN version constraint is satisfied
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
    raise RuntimeError("Requires cuDNN at least 8.9.3 to run flash attention.")
  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
    raise RuntimeError("Requires cuDNN at least 8.9.1 to run fused attention.")
def _dot_product_attention_fwd(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, _ = _dot_product_attention_fwd_p_wrapper.bind(
@@ -170,15 +215,6 @@ def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, f
def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)
  # Q, K and V must have the same data type
  assert query_dtype == key_dtype == value_dtype
  # Only support fp16 and bf16 here
  assert query_dtype in [jnp.float16, jnp.bfloat16]
  # Q, K and V must be 4-D tensors
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  output_shape = (batch, q_seq_len, num_heads, head_dim)
@@ -201,12 +237,7 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)
  # Q, K and V must have the same data type
  assert query_dtype == key_dtype == value_dtype
  # Only support fp16 and bf16 here
  assert query_dtype in [jnp.float16, jnp.bfloat16]
  # Q, K and V must be 4-D tensors
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4
  return (
      ShapedArray(
          query.shape, query_dtype
@@ -587,7 +618,7 @@ def _dot_product_attention(query: Array,
# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule)

# User interface
def dot_product_attention(query: Array,
                          key: Array,
@@ -618,10 +649,13 @@ def dot_product_attention(query: Array,
  Returns:
    Output of shape `[batch, length, num_heads, v_depth_per_head]`.
  """
  batch, q_seq_len, num_heads, head_dim = query.shape
  is_flash_attention = False
  if q_seq_len > 512:
    is_flash_attention = True
  # check if query, key and value layout meets cuDNN layout requirement
  check_qkv_layout(query, key, value)
  # check if flash attention is supported for this attention pattern
  is_flash_attention = check_is_flash_attention(query, key)
  # check if cuDNN is installed and if cuDNN version is sufficient
  check_cuDNN_version(is_flash_attention)
  variadic_args = (bias is not None, mask is not None)
  if bias is None:
    bias = jnp.zeros(0, dtype=query.dtype)
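
For reference, a minimal single-device call into the new user-facing wrapper is sketched below. This is not part of the diff: it assumes a CUDA device with cuDNN >= 8.9.3 (the 1024-token sequence takes the flash-attention path), and that bias and mask may be left as None, as the None handling above suggests; the keyword names follow the test added in this commit.

from functools import partial

import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (4, 1024, 8, 64)  # [batch, seq_len, num_heads, head_dim]
q = jax.random.normal(k1, shape, dtype=jnp.bfloat16)
k = jax.random.normal(k2, shape, dtype=jnp.bfloat16)
v = jax.random.normal(k3, shape, dtype=jnp.bfloat16)

attn = jax.jit(partial(dot_product_attention, bias=None, mask=None,
                       scale=1.0, is_causal_mask=False, dropout_rate=0.0))
out = attn(q, k, v)  # shape (4, 1024, 8, 64)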

@@ -0,0 +1,209 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
from typing import Any, Optional
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention
config.parse_flags_with_absl()
Array = jnp.ndarray
def f(query: Array,
      key: Array,
      value: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:
  output = dot_product_attention(
      query,
      key,
      value,
      scale=scale,
      bias=bias,
      mask=mask,
      is_causal_mask=causal_mask,
      dropout_rate=dropout_rate)
  return output
def f_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:
  out, f_vjp = jax.vjp(
      partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
  query_grad, key_grad, value_grad, _, _ = f_vjp(grad)
  return out, (query_grad, key_grad, value_grad)
def g(query: Array,
      key: Array,
      value: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:

  def get_large_negative_number(dtype):
    if jnp.issubdtype(dtype, jnp.inexact):
      dtype_max = jnp.finfo(dtype).max
    elif jnp.issubdtype(dtype, jnp.integer):
      dtype_max = jnp.iinfo(dtype).max
    else:
      raise ValueError('Unsupported dtype for inputs.')
    return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

  def get_causal_mask(input_t):
    large_negative_number = get_large_negative_number(input_t.dtype)
    t = input_t.shape[2]
    col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
    row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
    mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
    return mask[jnp.newaxis, jnp.newaxis, :, :]

  if scale != 1.0:
    query = query * scale
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
  if causal_mask:
    bias = get_causal_mask(attn_weights)
  if bias is not None:
    attn_weights = attn_weights + bias.astype(attn_weights.dtype)
  attn_weights = jax.nn.softmax(attn_weights)
  if dropout_rate > 0.:
    keep_prob = 1.0 - dropout_rate
    dropout_shape = list(attn_weights.shape)
    dropout_shape[-2] = 1
    dropout_rng = jax.random.PRNGKey(0)
    keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    multiplier = (
        keep.astype(attn_weights.dtype) /
        jnp.asarray(keep_prob, dtype=attn_weights.dtype))
    attn_weights = attn_weights * multiplier
  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
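
# Note on the masking convention above (illustrative, not part of the test): for a
# length-3 sequence, get_causal_mask returns an additive [1, 1, 3, 3] mask that is 0
# on and below the diagonal and a large negative value (-0.7 * the dtype's max) above
# it, roughly
#   [[0, -big, -big],
#    [0,    0, -big],
#    [0,    0,    0]]
# so the softmax assigns ~0 weight to future positions.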
def g_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:
  out_ref, g_vjp = jax.vjp(
      partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
  query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad)
  return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)
@jtu.with_config(jax_legacy_prng_key='allow')
class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.sample_product(
      batch_size=[4],
      seq_len=[256, 1024],
      num_heads=[8],
      head_dim=[64, 128],
      use_bias=[True],
      is_causal_mask=[False],
      dropout_rate=[0],
      scale=[0.5],
      dtype=[jnp.float16, jnp.bfloat16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
                head_dim: int, use_bias: bool, is_causal_mask: bool,
                dropout_rate: float, scale: float, dtype: jnp.dtype):
    if seq_len == 256 and is_causal_mask:
      self.skipTest("Fused attention does not support mask generation.")
    if seq_len == 256 and head_dim == 128:
      self.skipTest("Fused attention does not support head dim = 128.")
    if len(jax.local_devices()) < 4:
      self.skipTest("Requires at least 4 devices to run sharding tests.")
    os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
    k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
    query = jax.random.normal(
        k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    key = jax.random.normal(
        k2, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    value = jax.random.normal(
        k3, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    grad = jax.random.normal(
        k4, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    if use_bias:
      bias = jax.random.normal(
          k5, (batch_size, num_heads, seq_len, seq_len), dtype=dtype)
    else:
      bias = None
    jitted_f_train = jax.jit(partial(f_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
    jitted_g_train = jax.jit(partial(g_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
    devices = np.array(jax.local_devices()[:4])
    devices = devices.reshape((2, 2))
    with Mesh(devices, ('dp', 'tp')) as mesh:
      qkv_spec = PartitionSpec('dp', None, 'tp', None)
      if bias is not None:
        bias_spec = PartitionSpec('dp', 'tp', None, None)
      else:
        bias_spec = None
      query = jax.device_put(query, NamedSharding(mesh, qkv_spec))
      key = jax.device_put(key, NamedSharding(mesh, qkv_spec))
      value = jax.device_put(value, NamedSharding(mesh, qkv_spec))
      if bias is not None:
        bias = jax.device_put(bias, NamedSharding(mesh, bias_spec))
      grad = jax.device_put(grad, NamedSharding(mesh, qkv_spec))
      in_shardings = (qkv_spec, qkv_spec, qkv_spec, qkv_spec, bias_spec, None)
      out_shardings = (None, (qkv_spec, qkv_spec, qkv_spec))
      pjitted_f_train = pjit(
          jitted_f_train,
          in_shardings=in_shardings,
          out_shardings=out_shardings)
      pjitted_g_train = pjit(
          jitted_g_train,
          in_shardings=in_shardings,
          out_shardings=out_shardings)
      out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
      out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
      assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5)
      if seq_len > 512:
        # query_grad in flash attention is not deterministic
        assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
      else:
        assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
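
As a quick unsharded sanity check, the f (cuDNN) and g (pure-JAX reference) helpers above can also be compared directly. This is only a sketch, not part of the commit; it assumes a single CUDA device with cuDNN >= 8.9.3 (seq_len 1024 takes the flash-attention path), dropout disabled so both paths compute the same math, and looser tolerances for bfloat16.

def quick_check(dtype=jnp.bfloat16):
  k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
  shape = (2, 1024, 8, 64)  # [batch, seq_len, num_heads, head_dim]
  query = jax.random.normal(k1, shape, dtype=dtype)
  key = jax.random.normal(k2, shape, dtype=dtype)
  value = jax.random.normal(k3, shape, dtype=dtype)
  out = jax.jit(partial(f, scale=1.0, dropout_rate=0.0))(query, key, value)
  out_ref = jax.jit(partial(g, scale=1.0, dropout_rate=0.0))(query, key, value)
  assert jnp.allclose(out, out_ref, rtol=2e-2, atol=2e-2)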

@@ -1,122 +0,0 @@
import os
import argparse
from functools import partial
from typing import Any, Optional
from flax import linen as nn
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.fused_attention_stableHLO import dot_product_attention
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax import core, vmap
from jax import make_jaxpr
Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray
def f(input_q: Array,
      input_k: Array,
      input_v: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:
  output = dot_product_attention(
      input_q,
      input_k,
      input_v,
      scale=scale,
      bias=bias,
      mask=mask,
      is_cauasl_mask=causal_mask,
      dropout_rate=dropout_rate)
  return output

def train_step(input_q, input_k, input_v, bias, mask, grad, scale, causal_mask, dropout_rate):
  out, f_vjp = jax.vjp(
      vmap(partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate), in_axes=(0, 0, 0, 0, 0)),
      input_q, input_k, input_v, bias, mask
  )
  input_q_grad, input_k_grad, input_v_grad, bias_grad, _ = f_vjp(grad)
  return out, (input_q_grad, input_k_grad, input_v_grad, bias_grad, _)
def main():
  parser = argparse.ArgumentParser(description='T5X MHA Unit Test')
  parser.add_argument("--batch_size", dest="batch_size", type=int, default=16)
  parser.add_argument("--q_seq_len", dest="q_seq_len", type=int, default=1024)
  parser.add_argument("--kv_seq_len", dest="kv_seq_len", type=int, default=1024)
  parser.add_argument("--num_attn_heads", dest="num_attn_heads", type=int, default=16)
  parser.add_argument("--head_dim", dest="head_dim", type=int, default=64)
  parser.add_argument("--scale", dest="scale", type=float, default=1.0)
  parser.add_argument("--dropout_rate", dest="dropout_rate", type=float, default=0.)
  parser.add_argument("--bias", action="store_true")
  parser.add_argument("--mask", action="store_true")
  parser.add_argument("--causal_mask", action="store_true")
  parser.add_argument("--fwd_only", dest='fwd_only', action="store_true")
  parser.add_argument("--xla_dump_to", dest="xla_dump_to", type=str, default=None)
  args = parser.parse_args()

  if args.xla_dump_to:
    xla_flags = f'--xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_text --xla_dump_to={args.xla_dump_to}'
    os.environ['XLA_FLAGS'] = xla_flags

  dtype = jnp.bfloat16
  key = jax.random.PRNGKey(0)
  key1, key2 = jax.random.split(key, 2)
  input_q = jax.random.uniform(key2, (4, args.batch_size, args.q_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  input_k = jax.random.uniform(key2, (4, args.batch_size, args.kv_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  input_v = jax.random.uniform(key2, (4, args.batch_size, args.kv_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  bias = jax.random.uniform(key2, (4, args.batch_size, args.num_attn_heads, args.q_seq_len, args.kv_seq_len), dtype=dtype) if args.bias else None
  mask = jax.random.uniform(key2, (4, args.batch_size, args.num_attn_heads, args.q_seq_len, args.kv_seq_len), dtype=dtype) if args.mask else None
  grad = jax.random.uniform(key2, (4, args.batch_size, args.q_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)

  if args.fwd_only:
    jitted_fwd_step = jax.jit(vmap(partial(f, causal_mask=args.causal_mask, scale=args.scale, dropout_rate=0), in_axes=(0, 0, 0, 0, 0)))
    out = jitted_fwd_step(input_q, input_k, input_v, bias, mask)
    print(out[0, 0, 0, 0, :20])
    return out
  else:
    jitted_train_step = jax.jit(partial(train_step, causal_mask=args.causal_mask, scale=args.scale, dropout_rate=args.dropout_rate))
    # out, input_grad = jitted_train_step(input_q, input_k, input_v, bias, mask, grad)
    # print(input_grad[0][0,0,0,:20])
    # return out, input_grad
    devices = np.array(jax.local_devices())
    devices = devices.reshape((2, 4, 1, 1, 1))
    with Mesh(devices, ('p', 'b', 's', 'n', 'd')) as mesh:
      input_q = jax.device_put(input_q, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      input_k = jax.device_put(input_k, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      input_v = jax.device_put(input_v, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      bias = jax.device_put(bias, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      mask = jax.device_put(mask, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      grad = jax.device_put(grad, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      pjitter = pjit(
          jitted_train_step,
          in_shardings=(
              PartitionSpec('p', 'b', None, None, None),
              PartitionSpec('p', 'b', None, None, None),
              PartitionSpec('p', 'b', None, None, None),
              PartitionSpec('p', 'b', None, None, None),
              PartitionSpec('p', 'b', None, None, None),
              PartitionSpec('p', 'b', None, None, None)),
          out_shardings=(None,
                         (PartitionSpec('p', 'b', None, None, None),
                          PartitionSpec('p', 'b', None, None, None),
                          PartitionSpec('p', 'b', None, None, None),
                          None,
                          None)))
      out, grads = pjitter(input_q, input_k, input_v, bias, mask, grad)
      # print(make_jaxpr(pjitter)(input_q, input_k, input_v, bias, mask, grad))
      print(grads[0][0, 0, 0, 0, :20])

if __name__ == "__main__":
  main()