Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
New feature: add local (sliding window) attention support to jax.nn.dot_product_attention via a new local_window_size argument, with XLA and cuDNN paths and test coverage.
parent 859188b322
commit 541b3a3f75
@@ -785,6 +785,14 @@ def _get_causal_mask(T, S):
   mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
   return mask[None, None, :, :]
 
+def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
+  query_pos = jnp.array(range(T))
+  key_pos = jnp.array(range(S))
+  left_window, right_window = local_window_size
+  left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
+  right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
+  return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
+
 def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
   q_mask = True
   kv_mask = True
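For reference, a minimal standalone sketch (not part of the diff) of what the new _get_window_mask helper computes: query position t may attend to key positions s with t - left_window <= s <= t + right_window. The helper and variable names below are illustrative.

import jax.numpy as jnp

def window_mask(T, S, local_window_size):
  # Same comparison logic as _get_window_mask above, without the leading
  # [None, None, :, :] broadcast dims, so the band structure is easy to print.
  left_window, right_window = local_window_size
  query_pos = jnp.arange(T)
  key_pos = jnp.arange(S)
  left_mask = query_pos[:, None] <= key_pos[None, :] + left_window    # s >= t - left_window
  right_mask = query_pos[:, None] >= key_pos[None, :] - right_window  # s <= t + right_window
  return jnp.logical_and(left_mask, right_mask)

print(window_mask(4, 4, (1, 1)).astype(jnp.int32))
# [[1 1 0 0]
#  [1 1 1 0]
#  [0 1 1 1]
#  [0 0 1 1]]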
@@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen):
   mask = q_indices < q_seqlen[:, None]
   return mask[:, :, None, None]
 
-def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
+def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+                 local_window_size):
   if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
     return logits
 
@@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
     mask = _get_causal_mask(T, S)
     combined_mask = jnp.logical_and(combined_mask, mask)
 
+  if local_window_size is not None:
+    mask = _get_window_mask(T, S, local_window_size)
+    combined_mask = jnp.logical_and(combined_mask, mask)
+
   if q_seqlen is not None or kv_seqlen is not None:
     mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
     combined_mask = jnp.logical_and(combined_mask, mask)
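The hunk above simply ANDs the window mask into the running combined mask, alongside the causal and padding masks; a tiny sketch of that combination (values are illustrative, not taken from the diff):

import jax.numpy as jnp

T = S = 5
causal = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))           # s <= t
window = jnp.abs(jnp.arange(T)[:, None] - jnp.arange(S)) <= 2  # |t - s| <= 2
combined = jnp.logical_and(causal, window)
# A query may attend to a key only where every enabled mask allows it.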
@@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
   return padded_logits
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
-                                scale, q_seqlen, kv_seqlen):
+                                scale, q_seqlen, kv_seqlen, local_window_size):
   logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
   logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                       preferred_element_type=logits_dtype)
@@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   if bias is not None:
     logits = (logits + bias).astype(logits.dtype)
 
-  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
+  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+                               local_window_size)
 
   # Softmax and it is always carried out in fp32.
   padded_logits = padded_logits.astype(jnp.float32)
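As context for threading local_window_size into _dot_product_attention_core, a rough single-head sketch of the computation the two hunks above feed: masked logits, an fp32 softmax, then the value contraction. This is an assumption-laden simplification, not the diff's code.

import jax
import jax.numpy as jnp

def attention_core_sketch(query, key, value, mask, scale):
  # query: [B, T, H], key/value: [B, S, H], mask: [T, S] boolean.
  logits = jnp.einsum('BTH,BSH->BTS', query, key) * scale
  padded = jnp.where(mask, logits, jnp.finfo(jnp.float32).min)  # mask out disallowed keys
  probs = jax.nn.softmax(padded.astype(jnp.float32), axis=-1)   # softmax always in fp32
  return jnp.einsum('BTS,BSH->BTH', probs.astype(query.dtype), value)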
@@ -857,7 +871,8 @@ def _dot_product_attention_xla(
     is_causal: bool,
     scale: float,
     q_seqlen: Array | None,
-    kv_seqlen: Array | None):
+    kv_seqlen: Array | None,
+    local_window_size: tuple[int, int] | None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -875,11 +890,13 @@ def _dot_product_attention_xla(
     return t
   bias = _reshape_to_grouped(bias)
   mask = _reshape_to_grouped(mask)
-  vmapped_fn = jax.vmap(_dot_product_attention_core,
-                        in_axes=(3, None, None, 2, 2, None, None, None, None),
-                        out_axes=3)
+  vmapped_fn = jax.vmap(
+      _dot_product_attention_core,
+      in_axes=(3, None, None, 2, 2, None, None, None, None, None),
+      out_axes=3,
+  )
   encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
-                       q_seqlen, kv_seqlen)
+                       q_seqlen, kv_seqlen, local_window_size)
   encoded = jnp.reshape(encoded, (B, T, N, H))
   return encoded
 
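The only vmap change above is one extra None in in_axes, so the new local_window_size argument is broadcast unchanged to every mapped group rather than sliced along an axis. A toy sketch of that in_axes=None behavior (function and values are illustrative):

import jax
import jax.numpy as jnp

def scale_rows(row, factor):
  return row * factor

batched = jax.vmap(scale_rows, in_axes=(0, None))  # map over rows, broadcast factor
print(batched(jnp.arange(6.0).reshape(3, 2), 10.0))
# [[ 0. 10.]
#  [20. 30.]
#  [40. 50.]]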
@@ -894,6 +911,7 @@ def dot_product_attention(
     is_causal: bool = False,
     query_seq_lengths: ArrayLike | None = None,
     key_value_seq_lengths: ArrayLike | None = None,
+    local_window_size: int | tuple[int, int] | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
   r"""Scaled dot product attention function.
 
@@ -943,6 +961,12 @@ def dot_product_attention(
       :code:`(B)`
     key_value_seq_lengths: `int32` array of sequence lengths for key and value;
       shape :code:`(B)`
+    local_window_size: Window sizes that restrict self attention to each
+      token's local window. If set, this specifies the (left_window_size,
+      right_window_size) for each token. E.g., if local_window_size == (3, 2)
+      and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
+      to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted
+      as a symmetric window (window_size, window_size).
     implementation: A string to control which implementation backend to use.
       Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
       to `None`, which will automatically select the best available backend.
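Based on the docstring above, a hedged usage sketch of the new parameter (shapes and dtypes are arbitrary; only the XLA path is assumed here):

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 8, 8, 4, 16
q = jnp.zeros((B, T, N, H), jnp.bfloat16)
k = jnp.zeros((B, S, N, H), jnp.bfloat16)
v = jnp.zeros((B, S, N, H), jnp.bfloat16)

# Each query token attends to 3 tokens to its left, itself, and 2 to its right.
out = jax.nn.dot_product_attention(q, k, v, local_window_size=(3, 2),
                                   implementation='xla')

# A single int is normalized to a symmetric window, here (4, 4).
out_sym = jax.nn.dot_product_attention(q, k, v, local_window_size=4,
                                       implementation='xla')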
@@ -969,6 +993,8 @@ def dot_product_attention(
     query_seq_lengths = jnp.asarray(query_seq_lengths)
   if key_value_seq_lengths is not None:
     key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
+  if isinstance(local_window_size, int):
+    local_window_size = (local_window_size, local_window_size)
 
   def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
                              dtype: DType | None, name: str) -> None:
@@ -1003,6 +1029,7 @@ def dot_product_attention(
           query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
           scale=scale_val, q_seqlen=query_seq_lengths,
           kv_seqlen=key_value_seq_lengths,
+          local_window_size=local_window_size,
       )
     case 'cudnn':
       use_padding = (
@@ -1022,9 +1049,21 @@ def dot_product_attention(
         mask_type = MaskType.CAUSAL
       elif use_padding:
         mask_type = MaskType.PADDING
+      # cuDNN supports only a left window with an exclusive boundary when the
+      # causal mask is enabled.
+      sliding_window = None
+      if local_window_size is not None:
+        l_window, r_window = local_window_size
+        if r_window == 0 or mask_type == MaskType.CAUSAL:
+          sliding_window = l_window + 1
+        else:
+          raise ValueError(f"cuDNN doesn't support right window: {r_window} "
+                           "when causal mask is not used.")
+
       out = cudnn_dot_product_attention(
           query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
-          key_value_seq_lengths, scale=scale_val, mask_type=mask_type
+          key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
+          sliding_window_length=sliding_window,
       )
     case None:
       # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
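The l_window + 1 above converts this API's inclusive left window into an exclusive window length for cuDNN's sliding_window_length; a small check of that correspondence (assuming cuDNN treats the window as an exclusive bound on t - s):

import jax.numpy as jnp

T = 6
t = jnp.arange(T)[:, None]
s = jnp.arange(T)[None, :]
api_mask = (s >= t - 3) & (s <= t)       # local_window_size=(3, 0) plus causal
cudnn_mask = (t - s < 3 + 1) & (s <= t)  # exclusive window length l_window + 1
assert bool(jnp.all(api_mask == cudnn_mask))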
@@ -1033,6 +1072,7 @@ def dot_product_attention(
           query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
           scale=scale_val, q_seqlen=query_seq_lengths,
           kv_seqlen=key_value_seq_lengths,
+          local_window_size=local_window_size,
       )
     case _:
       raise ValueError(f"Unsupported implementation option: {implementation}")
@@ -38,11 +38,11 @@ import jax.numpy as jnp
 
 config.parse_flags_with_absl()
 
-def _is_required_cudnn_version_satisfied():
+def _is_required_cudnn_version_satisfied(min_cudnn_version):
   return (
       jtu.is_cuda_compute_capability_at_least("8.0") and
       cuda_versions is not None and
-      cuda_versions.cudnn_get_version() >= 8904
+      cuda_versions.cudnn_get_version() >= min_cudnn_version
   )
 
 def _check_cudnn_backend(fn, *args, **kwargs):
@@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       impl=['cudnn', 'xla'],
   )
   def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -102,13 +102,15 @@ class NNFunctionsTest(jtu.JaxTestCase):
 
   @parameterized.product(
       mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
-                 ('custom', 'padding'), ('bias', 'causal')],
+                 ('custom', 'padding'), ('bias', 'causal'),
+                 ('causal', 'sliding_window')],
   )
   def testDotProductAttentionMask(self, mask_mode):
-    if not _is_required_cudnn_version_satisfied():
-      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if isinstance(mask_mode, str):
       mask_mode = (mask_mode,)
+    min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
+    if not _is_required_cudnn_version_satisfied(min_cudnn_version):
+      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
 
     dtype = jnp.bfloat16
     B, S, T, N, H = 2, 128, 128, 4, 32
@@ -119,6 +121,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     grad = random.normal(keys[3], (B, T, N, H), dtype)
     bias, mask = None, None
     q_seqlen, kv_seqlen = None, None
+    window_size = None
 
     is_causal = 'causal' in mask_mode
     if 'padding' in mask_mode:
@@ -130,6 +133,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
       mask = custom_mask[None, None, :, :]
     if 'bias' in mask_mode:
       bias = random.normal(keys[4], (1, N, T, S), dtype)
+    if 'sliding_window' in mask_mode:
+      window_size = (3, 2) if is_causal else (3, 0)
 
     sdpa = nn.dot_product_attention
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
@@ -141,9 +146,11 @@ class NNFunctionsTest(jtu.JaxTestCase):
     # Convert the kargs to positional args for the jax.vjp.
     fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
         q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+        local_window_size=window_size,
     )
     fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
         q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+        local_window_size=window_size,
     )
     out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
     out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
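For completeness, the jax.vjp calls above return the primal output together with a pullback that maps an output cotangent to input gradients; a minimal sketch of that pattern on a toy function (unrelated to the test's actual tolerance checks, which lie outside this hunk):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * 2.0
out, f_vjp = jax.vjp(f, jnp.ones(3))
(grad_x,) = f_vjp(jnp.ones_like(out))  # cotangent with the output's shape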