Support dynamic masks in splash attention

This commit is contained in:
Gleb Pobudzey 2024-12-02 20:13:20 +00:00
parent 95cb0eb1c9
commit c0d23af42c
4 changed files with 286 additions and 18 deletions

View File

@ -46,7 +46,7 @@ NT_DIM_NUMBERS = (((1,), (1,)), ((), ())) # RHS transposed
class SegmentIds(NamedTuple):
"""SegmentIds for Q and KV sequences.
SegmentIds are a mechanims to ensure that there is no cross-attention between
SegmentIds are a mechanism to ensure that there is no cross-attention between
segments (fraction of a sequence) that have been concatenated together into a
sequence. Each array is a list of ids (integers). Only tokens with the same
id are allowed to attend to each other.
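For concreteness, a minimal illustrative sketch of segment ids for two segments of lengths 3 and 5 packed into a single length-8 sequence (assuming the q/kv fields of the NamedTuple above):

import numpy as np

# Tokens 0-2 belong to segment 0 and tokens 3-7 to segment 1; attention is
# only allowed between positions whose q and kv segment ids are equal.
q_segment_ids = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=np.int32)
kv_segment_ids = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=np.int32)
segment_ids = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)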
@ -2392,7 +2392,7 @@ class SplashAttentionKernel:
def _make_splash_attention(
mask: np.ndarray | mask_lib.MultiHeadMask,
mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask,
*,
block_sizes: BlockSizes | None = None,
is_mqa: bool,
@ -2415,14 +2415,26 @@ def _make_splash_attention(
if block_sizes is None:
block_sizes = BlockSizes.get_default()
fwd_mask_info, mask_function_fwd = mask_info_lib.process_mask(
process_mask_fn = (
mask_info_lib.process_dynamic_mask
if isinstance(mask, jax.Array)
else mask_info_lib.process_mask
)
process_mask_dkv_fn = (
mask_info_lib.process_dynamic_mask_dkv
if isinstance(mask, jax.Array)
else mask_info_lib.process_mask_dkv
)
fwd_mask_info, mask_function_fwd = process_mask_fn(
mask,
(block_sizes.block_q, block_sizes.block_kv),
downcast_smem_data=downcast_smem_data,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
)
fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info)
dq_mask_info = None
@ -2432,7 +2444,7 @@ def _make_splash_attention(
dq_mask_info = None
else:
bq_dq, bkv_dq = block_sizes.block_q_dq, block_sizes.block_kv_dq
dq_mask_info, mask_function_dq = mask_info_lib.process_mask(
dq_mask_info, mask_function_dq = process_mask_fn(
mask,
(bq_dq, bkv_dq),
downcast_smem_data=downcast_smem_data,
@ -2442,7 +2454,7 @@ def _make_splash_attention(
assert (mask_function_fwd is None) == (mask_function_dq is None)
dq_mask_info = tree_util.tree_map(jnp.array, dq_mask_info)
bq_dkv, bkv_dkv = block_sizes.block_q_dkv, block_sizes.block_kv_dkv
dkv_mask_info, mask_function_dkv = mask_info_lib.process_mask_dkv(
dkv_mask_info, mask_function_dkv = process_mask_dkv_fn(
mask,
(bq_dkv, bkv_dkv),
downcast_smem_data=downcast_smem_data,
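With this dispatch in place, a traced boolean jax.Array can be passed where previously only a static numpy mask or MultiHeadMask was accepted. A minimal usage sketch, assuming the make_splash_mha_single_device wrapper from this module and a dense [head_count, q_seq_len, kv_seq_len] bool mask:

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash

head_count, q_seq_len, kv_seq_len = 2, 1024, 1024
# Build a causal mask at trace time; any boolean jax.Array of this shape works.
causal = jnp.tril(jnp.ones((q_seq_len, kv_seq_len), dtype=jnp.bool_))
dynamic_mask = jnp.broadcast_to(causal[None], (head_count, q_seq_len, kv_seq_len))
kernel = splash.make_splash_mha_single_device(dynamic_mask)

The resulting kernel is invoked on q, k, v exactly as with a static mask; only the mask-processing path differs.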

View File

@ -18,9 +18,13 @@ from __future__ import annotations
import collections
from collections.abc import Callable
import functools
import math
from typing import NamedTuple
import jax
from jax import util as jax_util
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
import jax.numpy as jnp
import numpy as np
# mypy: ignore-errors
@ -65,10 +69,10 @@ class MaskInfo(NamedTuple):
causal this is just np.arange(q_sequence_length).
"""
data_next: np.ndarray | None
mask_next: np.ndarray | None
block_mask: np.ndarray | None
partial_mask_blocks: np.ndarray | None
data_next: np.ndarray | jax.Array | None
mask_next: np.ndarray | jax.Array | None
block_mask: np.ndarray | jax.Array | None
partial_mask_blocks: np.ndarray | jax.Array | None
q_sequence: np.ndarray | None
@ -245,7 +249,7 @@ def _get_mask_info_for_shard(
mask_next = np.zeros(output_shape, dtype=np.int32)
data_next = np.zeros(output_shape, dtype=np.int32)
# If the mask is completelly zero'd out return freshly initialized outputs.
# If the mask is completely zero'd out return freshly initialized outputs.
if not data_coords:
return data_next, mask_next
@ -304,6 +308,152 @@ def _get_mask_info_for_shard(
return data_next, mask_next
def _process_dynamic_mask(
mask: jax.Array,
block_shape: tuple[int, int],
is_dkv: bool,
*,
downcast_smem_data: bool = True,
head_shards: int = 1,
q_seq_shards: int = 1,
shrink_grid: bool = True,
) -> tuple[MaskInfo, None]:
"""Similar to `_process_mask` but the mask must be a dynamic array.
Since the mask is dynamic, we can't know the exact number of partial mask
blocks at trace time. Therefore, the entire mask is materialized in
`partial_mask_blocks`.
Note that we can still populate MaskInfo to skip fully-masked blocks.
Args:
mask: A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense
mask to process.
block_shape: A Tuple[int, int] representing the shape of the Pallas grid
block.
is_dkv: True if we are processing the dKV mask.
downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to
a data type smaller than np.int32 (if possible).
head_shards: Number of head shards of the mesh in which the kernel is
launched.
q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is
launched.
shrink_grid: Whether or not we should apply the grid shrinking optimization.
This is currently ignored.
Returns:
`MaskInfo`, a sparse representation of the dense mask.
Raises:
ValueError: if the input mask is invalid or the block sizes are not
compatible with the mask sizes.
"""
del shrink_grid
# TODO(pobudzey): Properly support sharding.
if head_shards != 1 or q_seq_shards != 1:
raise ValueError('Dynamic mask processing does not support sharding.')
if len(mask.shape) != 3:
raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.')
if mask.dtype != jnp.bool:
raise ValueError(f'Expected a bool mask, instead got: {mask.dtype}.')
head_count, q_seq_len, kv_seq_len = mask.shape
q_block_size, kv_block_size = block_shape
q_blocks_count, q_mod = divmod(q_seq_len, q_block_size)
kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size)
if q_mod != 0:
raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.')
if kv_mod != 0:
raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.')
block_mask_shape = (
head_count,
q_blocks_count,
kv_blocks_count,
)
# Tile the last 2 dimensions of the mask into 2D tiles of size `block_shape`.
partial_mask_blocks = (
mask.reshape(
head_count,
q_blocks_count,
q_block_size,
kv_blocks_count,
kv_block_size,
)
.swapaxes(-2, -3)
.astype(np.bool_)
)
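# Illustrative example: with head_count=1, q_seq_len=kv_seq_len=8 and
# block_shape=(2, 4), the reshape/swapaxes above turn the [1, 8, 8] mask into
# a [1, 4, 2, 2, 4] array of 2x4 tiles indexed by (head, q_block, kv_block).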
# The block mask is 2 for blocks where all entries are True, 1 for blocks
# with a mix of True and False entries, and 0 for fully masked (all False)
# blocks.
is_full_mask = jnp.all(partial_mask_blocks, axis=(-1, -2))
is_empty_mask = jnp.logical_not(jnp.any(partial_mask_blocks, axis=(-1, -2)))
block_mask = jnp.ones(block_mask_shape, dtype=np.int32)
block_mask = jnp.where(is_full_mask, 2, block_mask)
block_mask = jnp.where(is_empty_mask, 0, block_mask)
# TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline.
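# mask_next maps each partial block to its flattened index into
# partial_mask_blocks (after the reshape below); full and empty blocks point
# at 0 since they never read a partial mask block.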
mask_next = jnp.where(
jnp.logical_or(is_empty_mask, is_full_mask),
0,
jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape(
block_mask_shape
),
)
# data_next stores the index of the next non-empty data block in the sequence.
# The indices of empty blocks are set to 0 to avoid copying extra data when
# pipelining.
if is_dkv:
data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None]
else:
data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :]
data_next = jnp.broadcast_to(data_next, block_mask_shape)
data_next = jnp.where(is_empty_mask, 0, data_next)
partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape)
if is_dkv:
partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2)
def _downcast(array: jax.Array, max_value: int) -> jax.Array:
if array.size == 0:
return array
if array.dtype != np.int32:
raise ValueError(f'Expected int32 input, but got {array.dtype}.')
if max_value <= np.iinfo(np.int8).max:
return array.astype(np.int8)
elif max_value <= np.iinfo(np.int16).max:
return array.astype(np.int16)
else:
return array.astype(np.int32)
if downcast_smem_data:
block_mask = block_mask.astype(np.int8)  # values are in {0, 1, 2}
data_next = _downcast(
data_next, q_blocks_count if is_dkv else kv_blocks_count
)
mask_next = _downcast(mask_next, math.prod(block_mask_shape))
return (
MaskInfo(
data_next=data_next,
mask_next=mask_next,
block_mask=block_mask,
partial_mask_blocks=partial_mask_blocks,
q_sequence=None,
),
None,
)
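As a concrete illustration of the encoding, running the routine above on a small causal mask yields a block_mask that distinguishes empty, partial, and fully visible tiles. A sketch mirroring the unit test added in this change, using a [1, 8, 8] causal mask and a (2, 4) block shape:

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib

mask = jnp.tril(jnp.ones((8, 8), dtype=jnp.bool_))[None]  # [1, 8, 8] causal mask
mask_info, _ = mask_info_lib.process_dynamic_mask(mask, block_shape=(2, 4))
# mask_info.block_mask[0] == [[1, 0], [1, 0], [2, 1], [2, 1]]:
#   0 = fully masked (block is skipped), 1 = partial (reads the corresponding
#   entry of partial_mask_blocks), 2 = fully visible (no masking applied).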
# When used in a transformer network with multiple layers, the SplashAttention
# kernel is created several times with the same mask. Cache MaskInfo to avoid
# blowing up compile times. Ideally the size of the cache should be determined
@ -410,7 +560,7 @@ def _process_mask(
mask_id_to_heads[mask_id].append(head)
mask_id_to_head_shards[mask_id].add(head_shard)
# If we have at most one unique mask per each head shard, then we can brodcast
# If we have at most one unique mask per each head shard, then we can broadcast
# the mask to all the heads in the shard. This is the common case.
# If we have more than one mask in each head shard, then the optimization
# cannot kick in and we use one mask for each head.
@ -699,9 +849,7 @@ def _process_mask(
current_block_mask,
current_data_next,
current_mask_next,
) in zip(
block_mask_shards, data_next_shards, mask_next_shards
):
) in zip(block_mask_shards, data_next_shards, mask_next_shards):
# For dKV shrinking happens along axis Q (the rows of MaskInfo), for
# fwd and dQ shrinking happens along axis KV (the columns of MaskInfo).
if is_dkv:
@ -924,3 +1072,6 @@ def _slice_mask_info(
process_mask = functools.partial(_process_mask, is_dkv=False)
process_mask_dkv = functools.partial(_process_mask, is_dkv=True)
process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False)
process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True)

View File

@ -292,6 +292,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]:
return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0))
def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array:
q_seq_len, kv_seq_len = mask.masks[0].shape
full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len))
dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0)
return dynamic_mask
@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
@ -322,9 +330,10 @@ class SplashAttentionTest(PallasBaseTest):
@parameterized.product(
is_mqa=(False, True),
is_segmented=(False, True),
is_dynamic_mask=(False, True),
)
@hp.given(hps.data())
def test_splash_attention(self, is_mqa, is_segmented, data):
def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data):
seed = data.draw(seed_strategy())
key = random.key(seed)
k1, k2, k3 = random.split(key, 3)
@ -353,6 +362,8 @@ class SplashAttentionTest(PallasBaseTest):
attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
if is_dynamic_mask:
mask = to_dynamic_mask(mask)
block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
if is_mqa:
@ -384,10 +395,11 @@ class SplashAttentionTest(PallasBaseTest):
@parameterized.product(
is_mqa=(False, True),
is_segmented=(False, True),
is_dynamic_mask=(False, True),
)
@hp.given(hps.data())
def test_splash_attention_fwd(
self, is_mqa, is_segmented, data
self, is_mqa, is_segmented, is_dynamic_mask, data
):
seed = data.draw(seed_strategy())
key = random.key(seed)
@ -416,6 +428,8 @@ class SplashAttentionTest(PallasBaseTest):
attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
if is_dynamic_mask:
mask = to_dynamic_mask(mask)
block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
if is_mqa:
attn_ref = splash.make_masked_mqa_reference(mask)
@ -531,10 +545,17 @@ class SplashAttentionTest(PallasBaseTest):
is_segmented=(False, True),
downcast_smem_data=(False, True),
use_fused_bwd_kernel=(False, True),
use_dynamic_mask=(False, True),
)
@hp.given(hps.data())
def test_splash_attention_bwd(
self, is_mqa, is_segmented, downcast_smem_data, use_fused_bwd_kernel, data
self,
is_mqa,
is_segmented,
downcast_smem_data,
use_fused_bwd_kernel,
use_dynamic_mask,
data,
):
seed = data.draw(seed_strategy())
key = random.key(seed)
@ -563,6 +584,8 @@ class SplashAttentionTest(PallasBaseTest):
attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
if use_dynamic_mask:
mask = to_dynamic_mask(mask)
block_sizes = data.draw(
block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True,
use_fused_bwd_kernel=use_fused_bwd_kernel)

View File

@ -21,6 +21,7 @@ import jax
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
@ -2091,6 +2092,87 @@ class SplashAttentionMaskInfoTest(jtu.JaxTestCase):
self.assertIn("softmax", str(ctx.exception))
@parameterized.parameters((False,), (True,))
def test_dynamic_mask(self, is_dkv: bool):
head_count, q_seq_len, kv_seq_len = 1, 8, 8
block_shape = (2, 4)
mask = jnp.stack([_make_causal_mask((q_seq_len, kv_seq_len))] * head_count)
process_dynamic_mask_fn = jax.jit(
mask_info_lib.process_dynamic_mask,
static_argnames=["block_shape", "is_dkv"],
)
mask_info, _ = process_dynamic_mask_fn(
mask, block_shape=block_shape, is_dkv=is_dkv
)
_expected_block_mask = np.array(
[[
[1, 0],
[1, 0],
[2, 1],
[2, 1],
]],
dtype=np.int8,
)
_expected_partial_mask_blocks = np.array(
[
[[1, 0, 0, 0], [1, 1, 0, 0]],
[[0, 0, 0, 0], [0, 0, 0, 0]],
[[1, 1, 1, 0], [1, 1, 1, 1]],
[[0, 0, 0, 0], [0, 0, 0, 0]],
[[1, 1, 1, 1], [1, 1, 1, 1]],
[[1, 0, 0, 0], [1, 1, 0, 0]],
[[1, 1, 1, 1], [1, 1, 1, 1]],
[[1, 1, 1, 0], [1, 1, 1, 1]],
],
dtype=np.bool_,
)
_expected_mask_next = np.array(
[[
[0, 0],
[2, 0],
[0, 5],
[0, 7],
]],
dtype=np.int8,
)
_expected_data_next = np.array(
[[
[0, 0],
[0, 0],
[0, 1],
[0, 1],
]],
dtype=np.int8,
)
if is_dkv:
_expected_partial_mask_blocks = _expected_partial_mask_blocks.swapaxes(
-1, -2
)
_expected_data_next = np.array(
[[
[0, 0],
[1, 0],
[2, 2],
[3, 3],
]],
dtype=np.int8,
)
self.assertArraysEqual(mask_info.block_mask, _expected_block_mask)
self.assertArraysEqual(
mask_info.partial_mask_blocks,
_expected_partial_mask_blocks,
)
self.assertArraysEqual(mask_info.mask_next, _expected_mask_next)
self.assertArraysEqual(mask_info.data_next, _expected_data_next)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())