New feature

kaixih 2024-08-26 17:32:38 +00:00
parent 859188b322
commit 541b3a3f75
2 changed files with 62 additions and 15 deletions

View File

@ -785,6 +785,14 @@ def _get_causal_mask(T, S):
mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return mask[None, None, :, :]
def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
query_pos = jnp.array(range(T))
key_pos = jnp.array(range(S))
left_window, right_window = local_window_size
left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_mask = True
kv_mask = True
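
As a sanity check on the new `_get_window_mask` above, here is a minimal standalone sketch (the `window_mask` helper and the tiny sizes are illustrative only, not part of the diff) showing the banded structure it produces:

import jax.numpy as jnp

def window_mask(T, S, local_window_size):
  # Same banding rule as _get_window_mask, restated for a plain 2-D case:
  # query i may attend to key j iff  i - left <= j <= i + right.
  left_window, right_window = local_window_size
  query_pos = jnp.arange(T)[:, None]
  key_pos = jnp.arange(S)[None, :]
  left_mask = query_pos <= key_pos + left_window
  right_mask = query_pos >= key_pos - right_window
  return jnp.logical_and(left_mask, right_mask)

print(window_mask(5, 5, (1, 1)).astype(jnp.int32))
# [[1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 1 1 1 0]
#  [0 0 1 1 1]
#  [0 0 0 1 1]]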
@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen):
mask = q_indices < q_seqlen[:, None]
return mask[:, :, None, None]
-def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
+def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+local_window_size):
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
return logits
@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)
if local_window_size is not None:
mask = _get_window_mask(T, S, local_window_size)
combined_mask = jnp.logical_and(combined_mask, mask)
if q_seqlen is not None or kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)
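
Because the masks are simply AND-ed into `combined_mask`, the interplay between the causal mask and the window mask is easy to check in isolation; a small self-contained sketch (names and sizes are illustrative, not part of the diff):

import jax.numpy as jnp

T = S = 6
q = jnp.arange(T)[:, None]
k = jnp.arange(S)[None, :]
causal = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
window = (q <= k + 2) & (q >= k - 3)          # local_window_size == (2, 3)
combined = jnp.logical_and(causal, window)
# The causal mask already removes every key to the right of the query, so the
# right half of the window never survives the AND: combined[i, j] is True only
# for i - 2 <= j <= i. This is what lets the cuDNN branch below fold
# (left, right) into a single left-window length whenever is_causal is set.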
@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
return padded_logits
def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
-scale, q_seqlen, kv_seqlen):
+scale, q_seqlen, kv_seqlen, local_window_size):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
if bias is not None:
logits = (logits + bias).astype(logits.dtype)
-padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
+padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+local_window_size)
# Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(jnp.float32)
@ -857,7 +871,8 @@ def _dot_product_attention_xla(
is_causal: bool,
scale: float,
q_seqlen: Array | None,
-kv_seqlen: Array | None):
+kv_seqlen: Array | None,
+local_window_size: tuple[int, int] | None):
B, T, N, H = query.shape
_, S, K, _ = key.shape
@ -875,11 +890,13 @@ def _dot_product_attention_xla(
return t
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
-vmapped_fn = jax.vmap(_dot_product_attention_core,
-in_axes=(3, None, None, 2, 2, None, None, None, None),
-out_axes=3)
+vmapped_fn = jax.vmap(
+_dot_product_attention_core,
+in_axes=(3, None, None, 2, 2, None, None, None, None, None),
+out_axes=3,
+)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
-q_seqlen, kv_seqlen)
+q_seqlen, kv_seqlen, local_window_size)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded
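
The only change to the vmap itself is the extra `None` in `in_axes`, which broadcasts `local_window_size` (a plain Python tuple) unchanged to every slice of the grouped axis; the same pattern in isolation, with made-up names:

import jax
import jax.numpy as jnp

def scale_by_window(x, window):
  # `window` arrives unbatched because its in_axes entry is None.
  left, right = window
  return x * (left + right)

vmapped = jax.vmap(scale_by_window, in_axes=(0, None), out_axes=0)
y = vmapped(jnp.ones((4, 3)), (3, 2))   # the tuple is shared by all 4 slices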
@ -894,6 +911,7 @@ def dot_product_attention(
is_causal: bool = False,
query_seq_lengths: ArrayLike | None = None,
key_value_seq_lengths: ArrayLike | None = None,
local_window_size: int | tuple[int, int] | None = None,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
r"""Scaled dot product attention function.
@ -943,6 +961,12 @@ def dot_product_attention(
:code:`(B)`
key_value_seq_lengths: `int32` array of sequence lengths for key and value;
shape :code:`(B)`
local_window_size: Window sizes that restrict self attention so that each
token only attends to keys within its local window. If set, this specifies
the (left_window_size, right_window_size) for each token. E.g., if
local_window_size == (3, 2) and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9],
token `c` can attend to [3, 4, 5, c, 7, 8]. If a single int is given, it is
interpreted as a symmetric window, (window_size, window_size).
implementation: A string to control which implementation backend to use.
Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
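
Putting the new argument together with the existing API, a usage sketch (shapes, dtypes, and the choice of window are illustrative, and assume this commit is applied):

import jax
import jax.numpy as jnp

B, T, N, H = 2, 16, 4, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, N, H), dtype=jnp.bfloat16)

# Each query attends to up to 3 keys on its left and 2 on its right; a single
# int, e.g. local_window_size=3, would mean the symmetric window (3, 3).
out = jax.nn.dot_product_attention(q, k, v, local_window_size=(3, 2))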
@ -969,6 +993,8 @@ def dot_product_attention(
query_seq_lengths = jnp.asarray(query_seq_lengths)
if key_value_seq_lengths is not None:
key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
if isinstance(local_window_size, int):
local_window_size = (local_window_size, local_window_size)
def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
dtype: DType | None, name: str) -> None:
@ -1003,6 +1029,7 @@ def dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case 'cudnn':
use_padding = (
@ -1022,9 +1049,21 @@ def dot_product_attention(
mask_type = MaskType.CAUSAL
elif use_padding:
mask_type = MaskType.PADDING
# cuDNN only supports a left window with an exclusive boundary; a nonzero
# right window is accepted only when the causal mask discards it anyway.
sliding_window = None
if local_window_size is not None:
l_window, r_window = local_window_size
if r_window == 0 or mask_type == MaskType.CAUSAL:
sliding_window = l_window + 1
else:
raise ValueError(f"cuDNN doesn't support right window: {r_window} "
"when causal mask is not used.")
out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
-key_value_seq_lengths, scale=scale_val, mask_type=mask_type
+key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
+sliding_window_length=sliding_window,
)
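
The conversion performed in the cuDNN branch above can be summarized by a hypothetical helper (the name is illustrative and `is_causal` stands in for the `MaskType.CAUSAL` check):

def to_cudnn_sliding_window(local_window_size, is_causal):
  # cuDNN takes a single left-window length with an exclusive boundary, so the
  # inclusive `left` becomes left + 1; a nonzero right window is only allowed
  # when the causal mask removes those positions anyway.
  if local_window_size is None:
    return None
  left, right = local_window_size
  if right == 0 or is_causal:
    return left + 1
  raise ValueError(f"cuDNN doesn't support right window: {right} "
                   "when causal mask is not used.")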
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
@ -1033,6 +1072,7 @@ def dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")

View File

@ -38,11 +38,11 @@ import jax.numpy as jnp
config.parse_flags_with_absl()
-def _is_required_cudnn_version_satisfied():
+def _is_required_cudnn_version_satisfied(min_cudnn_version):
return (
jtu.is_cuda_compute_capability_at_least("8.0") and
cuda_versions is not None and
-cuda_versions.cudnn_get_version() >= 8904
+cuda_versions.cudnn_get_version() >= min_cudnn_version
)
def _check_cudnn_backend(fn, *args, **kwargs):
@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
-if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@ -102,13 +102,15 @@ class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
-('custom', 'padding'), ('bias', 'causal')],
+('custom', 'padding'), ('bias', 'causal'),
+('causal', 'sliding_window')],
)
def testDotProductAttentionMask(self, mask_mode):
-if not _is_required_cudnn_version_satisfied():
-raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
+min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
+if not _is_required_cudnn_version_satisfied(min_cudnn_version):
+raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16
B, S, T, N, H = 2, 128, 128, 4, 32
@ -119,6 +121,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
grad = random.normal(keys[3], (B, T, N, H), dtype)
bias, mask = None, None
q_seqlen, kv_seqlen = None, None
window_size = None
is_causal = 'causal' in mask_mode
if 'padding' in mask_mode:
@ -130,6 +133,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
mask = custom_mask[None, None, :, :]
if 'bias' in mask_mode:
bias = random.normal(keys[4], (1, N, T, S), dtype)
if 'sliding_window' in mask_mode:
window_size = (3, 2) if is_causal else (3, 0)
sdpa = nn.dot_product_attention
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
@ -141,9 +146,11 @@ class NNFunctionsTest(jtu.JaxTestCase):
# Convert the kwargs to positional args for jax.vjp.
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
local_window_size=window_size,
)
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
local_window_size=window_size,
)
out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)