Support variable sequence lengths

kaixih 2024-07-18 17:03:49 +00:00
parent f2068bb4ad
commit 558000df7c
2 changed files with 148 additions and 54 deletions
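For orientation, here is a minimal usage sketch of the new arguments (not part of the commit; it assumes a JAX build that includes this change). Positions at or beyond the given per-batch lengths are excluded from the attention computation:

import jax.numpy as jnp
from jax import nn, random

B, T, S, N, H = 2, 8, 8, 4, 16
keys = random.split(random.PRNGKey(0), 3)
Q = random.normal(keys[0], (B, T, N, H), jnp.bfloat16)
K = random.normal(keys[1], (B, S, N, H), jnp.bfloat16)
V = random.normal(keys[2], (B, S, N, H), jnp.bfloat16)

# Per-batch valid lengths; the concrete values here are illustrative.
q_seqlen = jnp.array([T, T // 2], dtype=jnp.int32)
kv_seqlen = jnp.array([S, S // 2], dtype=jnp.int32)

out = nn.dot_product_attention(
    Q, K, V,
    query_seq_lengths=q_seqlen,
    key_value_seq_lengths=kv_seqlen,
    implementation='xla',
)
print(out.shape)  # (2, 8, 4, 16)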

@@ -35,7 +35,7 @@ from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DType
from jax._src.ops.special import logsumexp as _logsumexp
@@ -781,13 +781,48 @@ def _get_large_negative(dtype):
dtype_max = jnp.finfo(dtype).max
return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
def _get_causal_mask(T, S, dtype):
pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
return mask
def _get_causal_mask(T, S):
mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return mask[None, None, :, :]
def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_indices = jnp.arange(0, T)[None, :, None]
kv_indices = jnp.arange(0, S)[None, None, :]
q_mask = q_indices < q_seqlen[:, None, None]
kv_mask = kv_indices < kv_seqlen[:, None, None]
mask = jnp.logical_and(q_mask, kv_mask)
return mask[:, None, :, :]
def _get_padding_mask_encoded(T, q_seqlen):
q_indices = jnp.arange(0, T)[None, :]
mask = q_indices < q_seqlen[:, None]
return mask[:, :, None, None]
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
return logits
combined_mask = jnp.ones_like(logits, dtype=jnp.bool_)
if mask is not None:
assert mask.dtype == jnp.bool_
combined_mask = jnp.logical_and(combined_mask, mask)
T, S = logits.shape[2], logits.shape[3]
if is_causal:
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)
if q_seqlen is not None and kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)
large_negative_number = _get_large_negative(logits.dtype)
padded_logits = jnp.where(combined_mask, logits, large_negative_number)
return padded_logits
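As a reference for the helpers above, a tiny standalone sketch (not from the commit) of how the causal and padding masks combine inside _apply_masks, for one batch entry with T = S = 4, q_seqlen = 3, and kv_seqlen = 2 (illustrative values):

import jax.numpy as jnp

T = S = 4
q_seqlen = jnp.array([3], dtype=jnp.int32)
kv_seqlen = jnp.array([2], dtype=jnp.int32)

# Causal mask, shape (1, 1, T, S): True on and below the diagonal.
causal = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))[None, None, :, :]

# Padding mask, shape (B, 1, T, S): True where both the query position and
# the key/value position fall within their valid lengths.
q_ok = jnp.arange(T)[None, :, None] < q_seqlen[:, None, None]
kv_ok = jnp.arange(S)[None, None, :] < kv_seqlen[:, None, None]
padding = jnp.logical_and(q_ok, kv_ok)[:, None, :, :]

combined = jnp.logical_and(causal, padding)
print(combined[0, 0].astype(jnp.int32))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 0 0]
#  [0 0 0 0]]

Positions where the combined mask is False are filled with the large negative constant before the softmax.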
def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale):
scale, q_seqlen, kv_seqlen):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
@@ -797,24 +832,16 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
if bias is not None:
logits = (logits + bias).astype(logits.dtype)
if mask is not None:
assert mask.dtype == jnp.bool_
large_negative_number = _get_large_negative(logits.dtype)
padded_logits = jnp.where(mask, logits, large_negative_number)
else:
padded_logits = logits
if is_causal:
T, S = query.shape[1], key.shape[1]
mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype),
padded_logits.shape)
padded_logits = padded_logits + mask
padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
# Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(jnp.float32)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)
encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
if q_seqlen is not None and kv_seqlen is not None:
mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
encoded *= mask.astype(encoded.dtype)
return encoded
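The final multiply by the _get_padding_mask_encoded mask matters because masking the logits alone is not enough: a row whose logits are all replaced by the large negative constant still softmaxes to a uniform distribution, so padded query positions would otherwise attend uniformly to every key instead of producing zeros. A short sketch of that behavior (not part of the commit):

import jax
import jax.numpy as jnp

# Every logit in this row was replaced by the large negative constant.
masked_row = jnp.full((4,), -0.7 * jnp.finfo(jnp.float32).max)
print(jax.nn.softmax(masked_row))  # [0.25 0.25 0.25 0.25]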
def _dot_product_attention_xla(
@@ -824,7 +851,9 @@ def _dot_product_attention_xla(
bias: Array | None,
mask: Array | None,
is_causal: bool,
scale: float):
scale: float,
q_seqlen: Array | None,
kv_seqlen: Array | None):
B, T, N, H = query.shape
_, S, K, _ = key.shape
@@ -843,9 +872,10 @@ def _dot_product_attention_xla(
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
vmapped_fn = jax.vmap(_dot_product_attention_core,
in_axes=(3, None, None, 2, 2, None, None),
in_axes=(3, None, None, 2, 2, None, None, None, None),
out_axes=3)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
q_seqlen, kv_seqlen)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded
@@ -858,6 +888,8 @@ def dot_product_attention(
*,
scale: float | None = None,
is_causal: bool = False,
query_seq_lengths: ArrayLike | None = None,
key_value_seq_lengths: ArrayLike | None = None,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
r"""Scaled dot product attention function.
@@ -903,6 +935,10 @@ def dot_product_attention(
logits to mask out the non-causal parts of the attention matrix, but other
implementations like `cudnn` will avoid computing the non-causal regions,
providing speedups.
query_seq_lengths: `int32` array of sequence lengths for query; shape
:code:`(B)`
key_value_seq_lengths: `int32` array of sequence lengths for key and value;
shape :code:`(B)`
implementation: A string to control which implementation backend to use.
Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
@@ -925,46 +961,64 @@ def dot_product_attention(
value_arr = _ensure_4d(value)
bias = _ensure_4d(bias) if bias is not None else None
mask = _ensure_4d(mask) if mask is not None else None
if query_seq_lengths is not None:
query_seq_lengths = jnp.asarray(query_seq_lengths)
if key_value_seq_lengths is not None:
key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
dtype: DType | None, name: str) -> None:
if t is None:
return
if t.ndim != len(shape):
raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
if dtype is not None and t.dtype != dtype:
raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}")
for i in range(t.ndim):
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")
B, S, K, H = key_arr.shape
_check_has_shape(value_arr, [B, S, K, H], 'value')
_check_has_shape(query_arr, [B, -1, -1, H], 'query')
_check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value')
_check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query')
_check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask')
_check_shape_and_dtype(bias, [-1] * 4, None, 'bias')
_check_shape_and_dtype(query_seq_lengths, [B], jnp.int32,
'query_seq_lengths')
_check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32,
'key_value_seq_lengths')
if query_arr.shape[-2] % K != 0:
raise ValueError(f"The number of query heads must be a multiple of "
f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
if not (query_arr.dtype == key_arr.dtype == value_arr.dtype):
raise ValueError(f"query/key/value should have the same dtype, but got "
f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.")
if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
raise ValueError(f"Mask must be a 4D boolean tensor, but got "
f"rank={mask.ndim}, dtype={mask.dtype}.")
if bias is not None and bias.ndim != 4:
raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.")
scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
match implementation:
case 'xla':
out = _dot_product_attention_xla(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
)
case 'cudnn':
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
mask_type = MaskType.NO_MASK
if query_seq_lengths is not None and is_causal:
mask_type = MaskType.PADDING_CAUSAL
elif is_causal:
mask_type = MaskType.CAUSAL
elif query_seq_lengths is not None:
mask_type = MaskType.PADDING
out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, scale=scale_val, mask_type=mask_type
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
key_value_seq_lengths, scale=scale_val, mask_type=mask_type
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
# best backend.
out = _dot_product_attention_xla(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")

@@ -55,17 +55,21 @@ class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
causal_mode=[None, 'attr', 'mask'],
group_num=[1, 2, 4],
use_vmap=[False, True],
use_seqlen=[False, True],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
group_num, use_vmap, impl):
group_num, use_vmap, use_seqlen, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
if use_vmap and use_seqlen:
raise unittest.SkipTest("vmap cannot be used together with variable "
"sequence lengths")
sdpa = nn.dot_product_attention
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
@@ -77,41 +81,60 @@ class NNFunctionsTest(jtu.JaxTestCase):
bias = random.normal(keys[3], (1, N, T, S), dtype)
else:
bias = None
if use_seqlen:
q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
else:
q_seqlen = None
kv_seqlen = None
is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
is_causal = causal_mode == 'attr'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
if impl == 'cudnn':
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask)
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask)
out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
out_ans = sdpa_ans(Q, K, V, bias, causal_mask)
out_ans = sdpa_ans(Q, K, V, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
causal_mode=[None, 'attr', 'mask'],
group_num=[1, 2, 4],
use_vmap=[False, True],
use_seqlen=[False, True],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
group_num, use_vmap, impl):
group_num, use_vmap, use_seqlen, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
if use_vmap and use_seqlen:
raise unittest.SkipTest("vmap cannot be used together with variable "
"sequence lengths")
if use_seqlen and use_bias and impl == 'cudnn':
raise unittest.SkipTest("cudnn has limited support for dbias when using "
"variable sequence lengths")
sdpa = nn.dot_product_attention
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
@@ -124,24 +147,41 @@ class NNFunctionsTest(jtu.JaxTestCase):
bias = random.normal(keys[4], (1, N, T, S), dtype)
else:
bias = None
if use_seqlen:
q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
else:
q_seqlen = None
kv_seqlen = None
is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
is_causal = causal_mode == 'attr'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
_, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask)
dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)
# Convert the keyword arguments to positional ones.
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask,
q_seqlen, kv_seqlen)
dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
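The lambda above is needed because jax.vjp binds its differentiable inputs positionally, so keyword-only arguments such as query_seq_lengths have to be rebound as positional parameters first. A minimal sketch of the same pattern (hypothetical function, not from the commit):

import jax

def f(x, scale=1.0):
  return scale * x * x

# Rebind the keyword argument positionally so jax.vjp can see it.
g = lambda x, s: f(x, scale=s)
primal, g_vjp = jax.vjp(g, 2.0, 3.0)
dx, ds = g_vjp(1.0)
print(primal, dx, ds)  # 12.0, 12.0 (= 2*s*x), 4.0 (= x*x)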
if G != 1:
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
if use_vmap:
if use_vmap and not use_seqlen:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
_, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad)
_, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
else:
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
)
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen,
kv_seqlen)
dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]
if impl == 'cudnn':
lowered = jax.jit(sdpa_vjp_ans).lower(grad)