Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Support variable sequence lengths

Commit 558000df7c (parent f2068bb4ad)
@@ -35,7 +35,7 @@ from jax._src.core import AxisName
 from jax._src.cudnn.fused_attention_stablehlo import (
     dot_product_attention as cudnn_dot_product_attention, MaskType)
 from jax._src.numpy import util as numpy_util
-from jax._src.typing import Array, ArrayLike
+from jax._src.typing import Array, ArrayLike, DType
 from jax._src.ops.special import logsumexp as _logsumexp

@@ -781,13 +781,48 @@ def _get_large_negative(dtype):
   dtype_max = jnp.finfo(dtype).max
   return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

-def _get_causal_mask(T, S, dtype):
-  pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
-  mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
-  return mask
+def _get_causal_mask(T, S):
+  mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
+  return mask[None, None, :, :]
+
+def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
+  q_indices = jnp.arange(0, T)[None, :, None]
+  kv_indices = jnp.arange(0, S)[None, None, :]
+  q_mask = q_indices < q_seqlen[:, None, None]
+  kv_mask = kv_indices < kv_seqlen[:, None, None]
+  mask = jnp.logical_and(q_mask, kv_mask)
+  return mask[:, None, :, :]
+
+def _get_padding_mask_encoded(T, q_seqlen):
+  q_indices = jnp.arange(0, T)[None, :]
+  mask = q_indices < q_seqlen[:, None]
+  return mask[:, :, None, None]
+
+def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
+  if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
+    return logits
+
+  combined_mask = jnp.ones_like(logits, dtype=jnp.bool_)
+  if mask is not None:
+    assert mask.dtype == jnp.bool_
+    combined_mask = jnp.logical_and(combined_mask, mask)
+
+  T, S = logits.shape[2], logits.shape[3]
+
+  if is_causal:
+    mask = _get_causal_mask(T, S)
+    combined_mask = jnp.logical_and(combined_mask, mask)
+
+  if q_seqlen is not None and kv_seqlen is not None:
+    mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
+    combined_mask = jnp.logical_and(combined_mask, mask)
+
+  large_negative_number = _get_large_negative(logits.dtype)
+  padded_logits = jnp.where(combined_mask, logits, large_negative_number)
+  return padded_logits

 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
-                                scale):
+                                scale, q_seqlen, kv_seqlen):
   logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
   logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                       preferred_element_type=logits_dtype)
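For illustration, a minimal standalone sketch of how the padding and causal masks introduced above combine; it mirrors _get_causal_mask, _get_padding_mask_logits, and _apply_masks rather than importing the private helpers, and the shapes and lengths are arbitrary:

import jax.numpy as jnp

# Tiny example: batch of 2, T = S = 4, per-example valid lengths below.
B, T, S = 2, 4, 4
q_seqlen = jnp.array([3, 2], dtype=jnp.int32)
kv_seqlen = jnp.array([4, 2], dtype=jnp.int32)

# Causal mask, broadcast over batch and heads: (1, 1, T, S).
causal = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))[None, None, :, :]
# Padding mask in the logits layout: (B, 1, T, S).
q_ok = jnp.arange(T)[None, :, None] < q_seqlen[:, None, None]
kv_ok = jnp.arange(S)[None, None, :] < kv_seqlen[:, None, None]
padding = jnp.logical_and(q_ok, kv_ok)[:, None, :, :]

# Combined mask; False entries are the positions whose logits are replaced
# with a large negative number before the softmax.
combined = jnp.logical_and(causal, padding)
print(combined[1, 0].astype(jnp.int32))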
@@ -797,24 +832,16 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   if bias is not None:
     logits = (logits + bias).astype(logits.dtype)

-  if mask is not None:
-    assert mask.dtype == jnp.bool_
-    large_negative_number = _get_large_negative(logits.dtype)
-    padded_logits = jnp.where(mask, logits, large_negative_number)
-  else:
-    padded_logits = logits
-
-  if is_causal:
-    T, S = query.shape[1], key.shape[1]
-    mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype),
-                            padded_logits.shape)
-    padded_logits = padded_logits + mask
+  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)

   # Softmax and it is always carried out in fp32.
   padded_logits = padded_logits.astype(jnp.float32)
   probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)

   encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
+  if q_seqlen is not None and kv_seqlen is not None:
+    mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
+    encoded *= mask.astype(encoded.dtype)
   return encoded

 def _dot_product_attention_xla(
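The last added block above zeroes out query rows that lie past q_seqlen, so padded positions contribute nothing downstream. A small standalone sketch of that step with made-up shapes:

import jax.numpy as jnp

B, T, N, H = 2, 4, 1, 3
encoded = jnp.ones((B, T, N, H))
q_seqlen = jnp.array([3, 2], dtype=jnp.int32)

# (B, T, 1, 1) keep-mask, analogous to _get_padding_mask_encoded.
keep = (jnp.arange(T)[None, :] < q_seqlen[:, None])[:, :, None, None]
encoded = encoded * keep.astype(encoded.dtype)
print(encoded[:, :, 0, 0])
# [[1. 1. 1. 0.]
#  [1. 1. 0. 0.]]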
@@ -824,7 +851,9 @@ def _dot_product_attention_xla(
     bias: Array | None,
     mask: Array | None,
     is_causal: bool,
-    scale: float):
+    scale: float,
+    q_seqlen: Array | None,
+    kv_seqlen: Array | None):

   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -843,9 +872,10 @@ def _dot_product_attention_xla(
   bias = _reshape_to_grouped(bias)
   mask = _reshape_to_grouped(mask)
   vmapped_fn = jax.vmap(_dot_product_attention_core,
-                        in_axes=(3, None, None, 2, 2, None, None),
+                        in_axes=(3, None, None, 2, 2, None, None, None, None),
                         out_axes=3)
-  encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
+  encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
+                       q_seqlen, kv_seqlen)
   encoded = jnp.reshape(encoded, (B, T, N, H))
   return encoded

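The two extra None entries in in_axes broadcast q_seqlen and kv_seqlen to every vmapped head group, just as scale is broadcast. A toy sketch of that vmap pattern with a hypothetical per-group function (not the real attention core):

import jax
import jax.numpy as jnp

def per_group(x, seqlen):
  # x: (B, T); seqlen: (B,) valid lengths shared by all groups.
  keep = jnp.arange(x.shape[1])[None, :] < seqlen[:, None]
  return x * keep

# Map over a leading group axis; leave seqlen unmapped (in_axes=None).
vmapped = jax.vmap(per_group, in_axes=(0, None), out_axes=0)
xs = jnp.ones((4, 2, 3))                      # 4 groups of (B=2, T=3)
seqlen = jnp.array([3, 1], dtype=jnp.int32)
print(vmapped(xs, seqlen).shape)              # (4, 2, 3)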
@@ -858,6 +888,8 @@ def dot_product_attention(
     *,
     scale: float | None = None,
     is_causal: bool = False,
+    query_seq_lengths: ArrayLike | None = None,
+    key_value_seq_lengths: ArrayLike | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
   r"""Scaled dot product attention function.

@@ -903,6 +935,10 @@ def dot_product_attention(
       logits to mask out the non-causal parts of the attention matrix, but other
       implementations like `cudnn` will avoid computing the non-causal regions,
       providing speedups.
+    query_seq_lengths: `int32` array of sequence lengths for query; shape
+      :code:`(B)`
+    key_value_seq_lengths: `int32` array of sequence lengths for key and value;
+      shape :code:`(B)`
     implementation: A string to control which implementation backend to use.
       Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
       to `None`, which will automatically select the best available backend.
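A minimal usage sketch of the extended public API, assuming the two keywords land as documented above; shapes and dtypes are arbitrary, and implementation='xla' is chosen so the example does not require a cuDNN-capable GPU:

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 8, 8, 4, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
key = jax.random.normal(kk, (B, S, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(kv, (B, S, N, H), dtype=jnp.bfloat16)

out = jax.nn.dot_product_attention(
    query, key, value,
    is_causal=True,
    query_seq_lengths=jnp.array([8, 5], dtype=jnp.int32),
    key_value_seq_lengths=jnp.array([8, 5], dtype=jnp.int32),
    implementation='xla',
)
print(out.shape)  # (2, 8, 4, 16)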
@@ -925,46 +961,64 @@ def dot_product_attention(
   value_arr = _ensure_4d(value)
   bias = _ensure_4d(bias) if bias is not None else None
   mask = _ensure_4d(mask) if mask is not None else None
+  if query_seq_lengths is not None:
+    query_seq_lengths = jnp.asarray(query_seq_lengths)
+  if key_value_seq_lengths is not None:
+    key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)

-  def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
+  def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
+                             dtype: DType | None, name: str) -> None:
+    if t is None:
+      return
     if t.ndim != len(shape):
       raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
+    if dtype is not None and t.dtype != dtype:
+      raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}")
     for i in range(t.ndim):
       if shape[i] != -1 and t.shape[i] != shape[i]:
         raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")

   B, S, K, H = key_arr.shape
-  _check_has_shape(value_arr, [B, S, K, H], 'value')
-  _check_has_shape(query_arr, [B, -1, -1, H], 'query')
+  _check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value')
+  _check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query')
+  _check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask')
+  _check_shape_and_dtype(bias, [-1] * 4, None, 'bias')
+  _check_shape_and_dtype(query_seq_lengths, [B], jnp.int32,
+                         'query_seq_lengths')
+  _check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32,
+                         'key_value_seq_lengths')
   if query_arr.shape[-2] % K != 0:
     raise ValueError(f"The number of query heads must be a multiple of "
                      f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
-  if not (query_arr.dtype == key_arr.dtype == value_arr.dtype):
-    raise ValueError(f"query/key/value should have the same dtype, but got "
-                     f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.")
-  if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
-    raise ValueError(f"Mask must be a 4D boolean tensor, but got "
-                     f"rank={mask.ndim}, dtype={mask.dtype}.")
-  if bias is not None and bias.ndim != 4:
-    raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.")

   scale_val = (1.0 / np.sqrt(H)) if scale is None else scale

   match implementation:
     case 'xla':
       out = _dot_product_attention_xla(
-          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
+          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
+          scale=scale_val, q_seqlen=query_seq_lengths,
+          kv_seqlen=key_value_seq_lengths,
       )
     case 'cudnn':
-      mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
+      mask_type = MaskType.NO_MASK
+      if query_seq_lengths is not None and is_causal:
+        mask_type = MaskType.PADDING_CAUSAL
+      elif is_causal:
+        mask_type = MaskType.CAUSAL
+      elif query_seq_lengths is not None:
+        mask_type = MaskType.PADDING
       out = cudnn_dot_product_attention(
-          query_arr, key_arr, value_arr, bias, mask, scale=scale_val, mask_type=mask_type
+          query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
+          key_value_seq_lengths, scale=scale_val, mask_type=mask_type
       )
     case None:
       # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
       # best backend.
       out = _dot_product_attention_xla(
-          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
+          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
+          scale=scale_val, q_seqlen=query_seq_lengths,
+          kv_seqlen=key_value_seq_lengths,
       )
     case _:
       raise ValueError(f"Unsupported implementation option: {implementation}")
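The cuDNN branch above reduces mask selection to a small truth table over is_causal and the presence of sequence lengths. A standalone sketch of the same decision logic, using a stand-in enum instead of importing the private MaskType:

from enum import Enum, auto

class MaskKind(Enum):        # stand-in for MaskType
  NO_MASK = auto()
  CAUSAL = auto()
  PADDING = auto()
  PADDING_CAUSAL = auto()

def pick_mask_kind(is_causal: bool, has_seqlens: bool) -> MaskKind:
  if has_seqlens and is_causal:
    return MaskKind.PADDING_CAUSAL
  if is_causal:
    return MaskKind.CAUSAL
  if has_seqlens:
    return MaskKind.PADDING
  return MaskKind.NO_MASK

assert pick_mask_kind(True, True) is MaskKind.PADDING_CAUSAL
assert pick_mask_kind(False, False) is MaskKind.NO_MASK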
@@ -55,17 +55,21 @@ class NNFunctionsTest(jtu.JaxTestCase):
   @parameterized.product(
       dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
       use_bias=[False, True],
-      causal_mode=[None, 'is_causal', 'is_mask'],
+      causal_mode=[None, 'attr', 'mask'],
       group_num=[1, 2, 4],
       use_vmap=[False, True],
+      use_seqlen=[False, True],
       impl=['xla', 'cudnn'],
   )
   def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, impl):
+                                   group_num, use_vmap, use_seqlen, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
+    if use_vmap and use_seqlen:
+      raise unittest.SkipTest("vmap cannot be used together with variable "
+                              "sequence lengths")

     sdpa = nn.dot_product_attention
     B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
@@ -77,41 +81,60 @@ class NNFunctionsTest(jtu.JaxTestCase):
       bias = random.normal(keys[3], (1, N, T, S), dtype)
     else:
       bias = None
+    if use_seqlen:
+      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
+      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
+    else:
+      q_seqlen = None
+      kv_seqlen = None

-    is_causal = causal_mode == 'is_causal'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
+    is_causal = causal_mode == 'attr'
+    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None

     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
     sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

     if impl == 'cudnn':
-      lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask)
+      lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask,
+                                        query_seq_lengths=q_seqlen,
+                                        key_value_seq_lengths=kv_seqlen)
       hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
       self.assertIn('__cudnn$fmha', hlo)

-    if use_vmap:
-      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
     K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
     V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
-    out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask)
+    out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask,
+                       query_seq_lengths=q_seqlen,
+                       key_value_seq_lengths=kv_seqlen)
+    if use_vmap:
+      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)

-    out_ans = sdpa_ans(Q, K, V, bias, causal_mask)
+    out_ans = sdpa_ans(Q, K, V, bias, causal_mask,
+                       query_seq_lengths=q_seqlen,
+                       key_value_seq_lengths=kv_seqlen)
     self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)

   @parameterized.product(
       dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
       use_bias=[False, True],
-      causal_mode=[None, 'is_causal', 'is_mask'],
+      causal_mode=[None, 'attr', 'mask'],
       group_num=[1, 2, 4],
       use_vmap=[False, True],
+      use_seqlen=[False, True],
       impl=['xla', 'cudnn'],
   )
   def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, impl):
+                                   group_num, use_vmap, use_seqlen, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
+    if use_vmap and use_seqlen:
+      raise unittest.SkipTest("vmap cannot be used together with variable "
+                              "sequence lengths")
+    if use_seqlen and use_bias and impl == 'cudnn':
+      raise unittest.SkipTest("cudnn has limited support for dbias when using "
+                              "variable sequence lengths")

     sdpa = nn.dot_product_attention
     B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
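The tests build their reference by repeating the key/value heads G times, so grouped-query attention can be checked against plain multi-head attention; a tiny sketch of that expansion with arbitrary sizes:

import jax.numpy as jnp

B, S, KH, G, H = 1, 2, 2, 2, 3
K = jnp.arange(B * S * KH * H, dtype=jnp.float32).reshape(B, S, KH, H)
K_ref = jnp.repeat(K, G, axis=2)   # each KV head duplicated G times along the head axis
print(K.shape, K_ref.shape)        # (1, 2, 2, 3) (1, 2, 4, 3)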
@@ -124,24 +147,41 @@ class NNFunctionsTest(jtu.JaxTestCase):
       bias = random.normal(keys[4], (1, N, T, S), dtype)
     else:
       bias = None
+    if use_seqlen:
+      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
+      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
+    else:
+      q_seqlen = None
+      kv_seqlen = None

-    is_causal = causal_mode == 'is_causal'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
+    is_causal = causal_mode == 'attr'
+    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None

     K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
     V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    _, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask)
-    dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)
+    # Convert the keyword arguments to positional ones.
+    fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
+        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
+    )
+    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask,
+                              q_seqlen, kv_seqlen)
+    dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
     if G != 1:
       dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
       dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)

     sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
-    if use_vmap:
+    if use_vmap and not use_seqlen:
       sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
-    _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
-    dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad)
+      _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
+    else:
+      fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
+          q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
+      )
+      _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen,
+                                kv_seqlen)
+    dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]

     if impl == 'cudnn':
       lowered = jax.jit(sdpa_vjp_ans).lower(grad)
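jax.vjp differentiates only positional arguments, which is why both tests wrap the keyword-only sequence lengths in a small lambda before calling it; a minimal sketch of that pattern on a hypothetical toy function:

import jax
import jax.numpy as jnp

def f(x, *, length):
  # Keyword-only length, analogous to query_seq_lengths.
  keep = jnp.arange(x.shape[0]) < length
  return jnp.sum(x * keep)

x = jnp.arange(4.0)
fn = lambda x, length: f(x, length=length)   # expose the kwarg positionally
y, f_vjp = jax.vjp(fn, x, jnp.int32(3))
print(f_vjp(1.0)[0])                         # d y / d x = [1. 1. 1. 0.]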