add is_training && fix seqlen/head_dim checks

Cjkkkk 2024-03-14 14:06:38 -07:00
parent 9a00721a54
commit 204ee7ff0b
2 changed files with 187 additions and 85 deletions

jax/_src/cudnn/fused_attention_stablehlo.py

@@ -200,52 +200,57 @@ def check_qkv_layout(query, key, value):
"query should have layout [batch, q_seq, num_heads, head_dim], " \
"key and value should have layout [batch, kv_seq, num_heads, head_dim].")
def check_is_flash_attention(query, key):
def check_is_flash_attention(query, key, cudnn_version, has_bias, is_training):
batch, q_seq_len, num_heads, head_dim = query.shape
_, kv_sqe_len, _, _ = key.shape
is_cross_attention = q_seq_len != kv_sqe_len
_, kv_seq_len, _, _ = key.shape
# check if attention pattern is supported by flash attention or fused attention
if q_seq_len > 512 and kv_sqe_len > 512 and head_dim in [64, 128]:
# check if flash attention is supported
is_flash_attention = True
elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64:
if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64 \
and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
# check if regular fused attention is supported
# for training, seqlen should be divisible by 64
is_flash_attention = False
elif head_dim <= 128 and head_dim % 8 == 0 \
and (not is_training or not has_bias or q_seq_len % 2 == 0 and kv_seq_len % 2 == 0):
# check if flash attention is supported
# for training, for patterns with bias, seqlen should be divisible by 2
is_flash_attention = True
else:
raise NotImplementedError(
f"Unsupported sequence length Q {q_seq_len}, KV {kv_sqe_len} and head dim {head_dim}.")
return is_flash_attention, is_cross_attention
def check_cudnn_version(is_flash_attention, is_cross_attention):
# check if cuDNN is installed and if cuDNN version contraint is satisfied
if cuda_versions is None:
raise RuntimeError("cuDNN is not detected.")
elif is_flash_attention:
if not is_cross_attention and cuda_versions.cudnn_get_version() < 8903:
raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
if is_cross_attention and cuda_versions.cudnn_get_version() < 8904:
raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
f"Unsupported sequence length Q {q_seq_len}, KV {kv_seq_len} and head dim {head_dim}.")
# check if minimum cudnn version requirement is satisfied
if is_flash_attention and cudnn_version < 8904:
raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
elif not is_flash_attention and cudnn_version < 8901:
raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.")
return is_flash_attention
def check_cudnn_version():
# check if cuDNN is installed
if cuda_versions is None:
raise RuntimeError("cuDNN is not detected.")
return cuda_versions.cudnn_get_version()
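# Illustrative usage sketch of the two helpers above; the shapes, dtype and the
# imports are arbitrary examples. check_is_flash_attention returns True for the
# flash-attention path, False for the regular fused path, and raises
# NotImplementedError for unsupported seq_len / head_dim combinations.
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import (
    check_cudnn_version, check_is_flash_attention)

cudnn_version = check_cudnn_version()  # raises RuntimeError if cuDNN is not detected

# long sequences with head_dim % 8 == 0 -> flash attention (needs cuDNN >= 8.9.4)
q = jnp.empty((4, 1024, 4, 64), dtype=jnp.bfloat16)
k = jnp.empty((4, 1024, 4, 64), dtype=jnp.bfloat16)
assert check_is_flash_attention(q, k, cudnn_version, has_bias=False, is_training=True)

# both seq_lens <= 512 with head_dim == 64 -> regular fused attention (needs cuDNN >= 8.9.1)
q = jnp.empty((4, 256, 4, 64), dtype=jnp.bfloat16)
k = jnp.empty((4, 512, 4, 64), dtype=jnp.bfloat16)
assert not check_is_flash_attention(q, k, cudnn_version, has_bias=False, is_training=False)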
def _dot_product_attention_fwd(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
output, _ = _dot_product_attention_fwd_p_wrapper.bind(
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
is_causal_mask=is_causal_mask, is_training=is_training)
output = outputs[0]
return output
def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
output, activation = _dot_product_attention_fwd_p_wrapper.bind(
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
res = (query, key, value, bias, mask, activation, output)
return output, res
is_causal_mask=is_causal_mask, is_training=is_training)
res = (query, key, value, bias, mask, outputs[1], outputs[0]) if is_training else None
return outputs[0], res
def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output):
def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, res, grad_output):
query, key, value, bias, mask, activation, fwd_output = res
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, mask, activation, fwd_output, grad_output,
@@ -256,13 +261,13 @@ def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is
return grads
def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
# args: {Q, K, V, mask*, bias*}
output, activation = _dot_product_attention_fwd_p.bind(
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
return output, activation
is_causal_mask=is_causal_mask, is_training=is_training)
return outputs
def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
@@ -275,23 +280,34 @@ def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, f
return grads
def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
batch, q_seq_len, num_heads, head_dim = query.shape
_, kv_seq_len, _, _ = key.shape
output_shape = (batch, q_seq_len, num_heads, head_dim)
activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
softmax_stat_shape = (batch, num_heads, q_seq_len)
if q_seq_len > 512:
if is_flash_attention:
# is flash attention
if is_training:
return (
core.ShapedArray(output_shape, query_dtype), # output
core.ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat
)
else:
return (
core.ShapedArray(output_shape, query_dtype), # output
)
if is_training:
return (
core.ShapedArray(output_shape, query_dtype), # output
core.ShapedArray(activation_shape, query_dtype), # activation
)
else:
return (
core.ShapedArray(output_shape, query_dtype), # output
core.ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat
)
return (
core.ShapedArray(output_shape, query_dtype), # output
core.ShapedArray(activation_shape, query_dtype), # activation
)
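# Net effect of the branches above: when is_training is False both the flash and
# the regular fused path declare only the attention output; when is_training is
# True, flash attention additionally declares f32 softmax stats of shape
# (batch, num_heads, q_seq_len), while regular fused attention declares the full
# activation of shape (batch, num_heads, q_seq_len, kv_seq_len).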
def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
@@ -312,7 +328,7 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
)
def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@@ -344,19 +360,33 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False)
# create output types and layouts
if is_flash_attention:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type),
ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
]
result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
if is_training:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type),
ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
]
result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
else:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type)
]
result_layouts = [output_layout] + default_layouts(scratch_shape)
else:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type),
ir.RankedTensorType.get(activation_shape, query_type.element_type),
]
result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
if is_training:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type),
ir.RankedTensorType.get(activation_shape, query_type.element_type),
]
result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
else:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(scratch_shape, scratch_type),
]
result_layouts = [output_layout] + default_layouts(scratch_shape)
# create custom call here
out = mlir.custom_call(
custom_call_name,
@@ -368,7 +398,10 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
)
# drop scratch memory
# output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim)
return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]
if is_training:
return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]
else:
return [hlo.transpose(out.results[0], output_transpose_perm)]
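# out.results is ordered [output, workspace, softmax_stat / activation]: the cuDNN
# workspace in results[1] is scratch only and is never surfaced to JAX, and
# results[2] exists only when is_training is True.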
def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
@@ -451,11 +484,14 @@ def _check_valid_batch_dims(bdims):
raise NotImplementedError("Currently only support batch_dim in [0, None], " \
f"but got {dim=}")
def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, mask = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim
if is_training:
out_bdims = query_bdim, query_bdim
else:
out_bdims = (query_bdim,)
*batch_tuple, q_seq_len, num_heads, head_dim = query.shape
*_, kv_seq_len, _, _ = key.shape
@@ -470,19 +506,24 @@ def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed,
if has_mask:
mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
output, activation = _dot_product_attention_fwd_p_wrapper.bind(
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
is_causal_mask=is_causal_mask, is_training=is_training)
# reshape to original shape
output = outputs[0]
output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
if is_flash_attention:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
if is_training:
activation = outputs[1]
if is_flash_attention:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
else:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
return (output, activation), out_bdims
else:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
return (output, activation), out_bdims
return (output,), out_bdims
def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
_check_valid_batch_dims(batch_dims)
@@ -556,7 +597,7 @@ def _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_
raise ValueError("Sharding on mask sequence dim is not allowed.")
# fwd custom partition
def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args):
def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training):
# only sharding on batch and num_head dim is allowed
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
@@ -569,21 +610,24 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args):
_check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec)
# keep out sharding same as query sharding since they have same shape
out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
# activation sharding
*batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
return (out_sharding, activation_sharding)
if is_training:
# activation sharding
*batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
return [out_sharding, activation_sharding]
return [out_sharding]
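# The saved activation / softmax stat reuses the query's batch, num_head and q_seq
# sharding and is replicated along its last dimension; it is part of the inferred
# output sharding only when is_training is True.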
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10,11))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
# bwd custom partition
@@ -673,7 +717,7 @@ dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
@@ -684,11 +728,13 @@ def _dot_product_attention(query: Array,
dropout_rate: float,
variadic_args: tuple[bool, ...],
is_flash_attention: bool,
is_causal_mask: bool):
is_causal_mask: bool,
is_training: bool):
output = _dot_product_attention_fwd(
query, key, value, bias, mask,
scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
is_training=is_training)
return output
# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
@@ -704,7 +750,8 @@ def dot_product_attention(query: Array,
scale: float = 1.0,
is_causal_mask: bool = False,
seed: int = 42,
dropout_rate: float = 0.):
dropout_rate: float = 0.,
is_training = False):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
@@ -725,16 +772,19 @@ def dot_product_attention(query: Array,
mask: mask used mask out logits with shape of `[batch, num_heads,
q_length, kv_length]`.
scale: scale for the query.
is_causal_mask: choose to apply a causal mask or not.
seed: used for dropout mask generation.
dropout_rate: dropout rate.
is_training: choose to save activation or not.
Returns:
Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
"""
# check if query, key and value layout meets cuDNN layout requirement
# check if cuDNN is installed
cudnn_version = check_cudnn_version()
# check query, key and value shape and data type
check_qkv_layout(query, key, value)
# check if flash attention is supported for this attention pattern
is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
# check if cuDNN is installed and if cuDNN version is sufficient
check_cudnn_version(is_flash_attention, is_cross_attention)
is_flash_attention = check_is_flash_attention(query, key, cudnn_version, bias is not None, is_training)
if mask is not None and is_causal_mask:
raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
if not is_flash_attention and is_causal_mask:
@@ -747,5 +797,5 @@ def dot_product_attention(query: Array,
output = _dot_product_attention(
query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask)
is_flash_attention, is_causal_mask, is_training)
return output
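A minimal usage sketch of the extended public API (array shapes, the scale, and the cotangent are arbitrary examples; a CUDA GPU with a recent cuDNN is assumed): inference keeps the default is_training=False, while differentiating through dot_product_attention needs is_training=True so the forward pass saves the softmax stats the cuDNN backward kernel consumes.

from functools import partial
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
key = jax.random.normal(k2, (4, 1024, 4, 64), dtype=jnp.bfloat16)
value = jax.random.normal(k3, (4, 1024, 4, 64), dtype=jnp.bfloat16)

# inference: only the attention output is computed
out = dot_product_attention(query, key, value, None, None, scale=1.0)

# training: is_training=True saves the activation so jax.vjp can drive the
# cuDNN backward kernel
out, sdpa_vjp = jax.vjp(
    partial(dot_product_attention, scale=1.0, is_training=True),
    query, key, value, None, None)
query_grad, key_grad, value_grad, _, _ = sdpa_vjp(jnp.ones_like(out))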

tests/fused_attention_stablehlo_test.py

@@ -16,7 +16,7 @@ from functools import partial
from absl.testing import absltest
from typing import Optional
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true --xla_dump_hlo_as_text --xla_dump_to=./hlo'
import numpy as np
import jax
@@ -25,7 +25,7 @@ from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, check_is_flash_attention
config.parse_flags_with_absl()
Array = jnp.ndarray
@@ -43,7 +43,7 @@ def sdpa_train(query: Array,
# convert bool mask to dtype mask
mask = mask.astype(query.dtype)
out, sdpa_vjp = jax.vjp(
partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate, is_training=True),
query, key, value, bias, mask)
query_grad, key_grad, value_grad, _, _ = sdpa_vjp(grad)
return out, (query_grad, key_grad, value_grad)
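# Note: is_training=True is what makes the jax.vjp call above work; with the default
# is_training=False the forward pass saves no activation / softmax stats for the
# cuDNN backward kernel to consume.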
@@ -201,6 +201,58 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
@jtu.run_on_devices("cuda")
def test_sdpa_inference(self):
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
key = jax.random.normal(
k2, (4, 1024, 4, 64), dtype=jnp.bfloat16)
value = jax.random.normal(
k3, (4, 1024, 4, 64), dtype=jnp.bfloat16)
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
qkv_spec = PartitionSpec('dp', None, 'tp', None)
qkv_sharding = NamedSharding(mesh, qkv_spec)
replicated = NamedSharding(mesh, PartitionSpec())
in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, replicated, replicated)
out_shardings = replicated
query = jax.device_put(query, qkv_sharding)
key = jax.device_put(key, qkv_sharding)
value = jax.device_put(value, qkv_sharding)
jitted_sdpa_inference = jax.jit(
partial(dot_product_attention, scale=1.0, is_causal_mask=False, dropout_rate=0),
in_shardings=in_shardings,
out_shardings=out_shardings
)
jitted_sdpa_inference_ref = jax.jit(
partial(sdpa_ref, scale=1.0, is_causal_mask=False, dropout_rate=0),
in_shardings=in_shardings,
out_shardings=out_shardings
)
out = jitted_sdpa_inference(query, key, value, None, None)
out_ref = jitted_sdpa_inference_ref(query, key, value, None, None)
self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
def test_sdpa_utils(self):
test_cases = {
(256, 512, 64, 8905, False, False): False,
(1, 257, 64, 8905, False, True): True,
(1, 1024, 64, 8905, False, False): True,
(1024, 1024, 64, 8905, False, False): True,
(1024, 1024, 128, 8905, False, False): True,
}
for k, v in test_cases.items():
sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
query = jnp.empty((4, sql_q, 4, head_dim))
key = jnp.empty((4, sql_v, 4, head_dim))
self.assertEqual(check_is_flash_attention(query, key, cudnn_version, has_bias, is_training), v)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())