mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

Support dynamic masks in splash attention

This commit is contained in:
parent 95cb0eb1c9
commit c0d23af42c
@@ -46,7 +46,7 @@ NT_DIM_NUMBERS = (((1,), (1,)), ((), ()))  # RHS transposed
 class SegmentIds(NamedTuple):
   """SegmentIds for Q and KV sequences.
 
-  SegmentIds are a mechanims to ensure that there is no cross-attention between
+  SegmentIds are a mechanism to ensure that there is no cross-attention between
   segments (fraction of a sequence) that have been concatenated together into a
   sequence. Each array is a list of ids (integers). Only tokens with the same
   id are allowed to attend to each other.
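
For intuition, here is a small illustrative sketch of the segment-id mechanism the docstring describes (plain jnp code, not part of this diff; the q/kv field names follow the SegmentIds NamedTuple):

    import jax.numpy as jnp

    # Two segments of lengths 3 and 5 packed into a single sequence of 8 tokens.
    q_ids = jnp.asarray([0, 0, 0, 1, 1, 1, 1, 1])
    kv_ids = jnp.asarray([0, 0, 0, 1, 1, 1, 1, 1])

    # Tokens may only attend to tokens carrying the same id, so the implied
    # attention mask is block-diagonal over the segments.
    segment_mask = q_ids[:, None] == kv_ids[None, :]

    # The kernel would consume these as SegmentIds(q=q_ids, kv=kv_ids).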
@@ -2392,7 +2392,7 @@ class SplashAttentionKernel:
 
 
 def _make_splash_attention(
-    mask: np.ndarray | mask_lib.MultiHeadMask,
+    mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask,
     *,
     block_sizes: BlockSizes | None = None,
     is_mqa: bool,
@@ -2415,14 +2415,26 @@ def _make_splash_attention(
 
   if block_sizes is None:
     block_sizes = BlockSizes.get_default()
-  fwd_mask_info, mask_function_fwd = mask_info_lib.process_mask(
+  process_mask_fn = (
+      mask_info_lib.process_dynamic_mask
+      if isinstance(mask, jax.Array)
+      else mask_info_lib.process_mask
+  )
+
+  process_mask_dvk_fn = (
+      mask_info_lib.process_dynamic_mask_dkv
+      if isinstance(mask, jax.Array)
+      else mask_info_lib.process_mask_dkv
+  )
+
+  fwd_mask_info, mask_function_fwd = process_mask_fn(
       mask,
       (block_sizes.block_q, block_sizes.block_kv),
       downcast_smem_data=downcast_smem_data,
      head_shards=head_shards,
      q_seq_shards=q_seq_shards,
   )
 
   fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info)
 
   dq_mask_info = None
@@ -2432,7 +2444,7 @@ def _make_splash_attention(
     dq_mask_info = None
   else:
     bq_dq, bkv_dq = block_sizes.block_q_dq, block_sizes.block_kv_dq
-    dq_mask_info, mask_function_dq = mask_info_lib.process_mask(
+    dq_mask_info, mask_function_dq = process_mask_fn(
         mask,
         (bq_dq, bkv_dq),
         downcast_smem_data=downcast_smem_data,
@@ -2442,7 +2454,7 @@ def _make_splash_attention(
     assert (mask_function_fwd is None) == (mask_function_dq is None)
     dq_mask_info = tree_util.tree_map(jnp.array, dq_mask_info)
   bq_dkv, bkv_dkv = block_sizes.block_q_dkv, block_sizes.block_kv_dkv
-  dkv_mask_info, mask_function_dkv = mask_info_lib.process_mask_dkv(
+  dkv_mask_info, mask_function_dkv = process_mask_dvk_fn(
       mask,
       (bq_dkv, bkv_dkv),
       downcast_smem_data=downcast_smem_data,
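
Taken together, the kernel-side change is a type-based dispatch: a boolean jax.Array mask now selects the dynamic path (process_dynamic_mask / process_dynamic_mask_dkv), while np.ndarray and MultiHeadMask inputs keep the existing static path. A rough usage sketch, assuming the make_splash_mha wrapper exported by splash_attention_kernel (illustrative only, not taken from this diff):

    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.splash_attention import (
        splash_attention_kernel as splash,
    )

    num_heads, q_seq_len, kv_seq_len = 4, 1024, 1024

    # A dense boolean jax.Array mask: isinstance(mask, jax.Array) is True,
    # so _make_splash_attention routes through process_dynamic_mask.
    dynamic_mask = jnp.tril(
        jnp.ones((num_heads, q_seq_len, kv_seq_len), dtype=jnp.bool_)
    )

    kernel = splash.make_splash_mha(
        dynamic_mask, head_shards=1, q_seq_shards=1
    )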
@@ -18,9 +18,13 @@ from __future__ import annotations
 import collections
 from collections.abc import Callable
 import functools
+import math
 from typing import NamedTuple
+
+import jax
 from jax import util as jax_util
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+import jax.numpy as jnp
 import numpy as np
 
 # mypy: ignore-errors
@@ -65,10 +69,10 @@ class MaskInfo(NamedTuple):
     causal this is just np.arange(q_sequence_length).
   """
 
-  data_next: np.ndarray | None
-  mask_next: np.ndarray | None
-  block_mask: np.ndarray | None
-  partial_mask_blocks: np.ndarray | None
+  data_next: np.ndarray | jax.Array | None
+  mask_next: np.ndarray | jax.Array | None
+  block_mask: np.ndarray | jax.Array | None
+  partial_mask_blocks: np.ndarray | jax.Array | None
   q_sequence: np.ndarray | None
 
 
@@ -245,7 +249,7 @@ def _get_mask_info_for_shard(
   mask_next = np.zeros(output_shape, dtype=np.int32)
   data_next = np.zeros(output_shape, dtype=np.int32)
 
-  # If the mask is completelly zero'd out return freshly initialized outputs.
+  # If the mask is completely zero'd out return freshly initialized outputs.
   if not data_coords:
     return data_next, mask_next
 
@@ -304,6 +308,152 @@ def _get_mask_info_for_shard(
   return data_next, mask_next
 
 
+def _process_dynamic_mask(
+    mask: jax.Array,
+    block_shape: tuple[int, int],
+    is_dkv: bool,
+    *,
+    downcast_smem_data: bool = True,
+    head_shards: int = 1,
+    q_seq_shards: int = 1,
+    shrink_grid: bool = True,
+) -> tuple[MaskInfo, None]:
+  """Similar to `_process_mask` but the mask must be a dynamic array.
+
+  Since the mask is dynamic, we can't know the exact number of partial mask
+  blocks at trace time. Therefore, the entire mask is materialized in
+  `partial_mask_blocks`.
+
+  Note that we can still populate MaskInfo to skip fully-masked blocks.
+
+  Args:
+    mask: A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense
+      mask to process.
+    block_shape: A Tuple[int, int] representing the shape of the Pallas grid
+      block.
+    is_dkv: True if we are processing the dKV mask.
+    downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to
+      a data type smaller than np.int32 (if possible).
+    head_shards: Number of head shards of the mesh in which the kernel is
+      launched.
+    q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is
+      launched.
+    shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored.
+
+  Returns:
+    `MaskInfo`, a sparse representation of the dense mask.
+
+  Raises:
+    ValueError: if the input mask is invalid or the block sizes are not
+      compatible with the mask sizes.
+  """
+
+  del shrink_grid
+
+  # TODO(pobudzey): Properly support sharding.
+  if head_shards != 1 or q_seq_shards != 1:
+    raise ValueError('Dynamic mask processing does not support sharding.')
+
+  if len(mask.shape) != 3:
+    raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.')
+
+  if mask.dtype != jnp.bool:
+    raise ValueError(f'Expected a bool mask, instead got: {mask.dtype}.')
+
+  head_count, q_seq_len, kv_seq_len = mask.shape
+  q_block_size, kv_block_size = block_shape
+  q_blocks_count, q_mod = divmod(q_seq_len, q_block_size)
+  kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size)
+
+  if q_mod != 0:
+    raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.')
+  if kv_mod != 0:
+    raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.')
+
+  block_mask_shape = (
+      head_count,
+      q_blocks_count,
+      kv_blocks_count,
+  )
+
+  # Tile the last 2 dimensions of the mask into 2D tiles of size `block_shape`.
+  partial_mask_blocks = (
+      mask.reshape(
+          head_count,
+          q_blocks_count,
+          q_block_size,
+          kv_blocks_count,
+          kv_block_size,
+      )
+      .swapaxes(-2, -3)
+      .astype(np.bool_)
+  )
+
+  # The block mask is 2 for all blocks with all entries set to True and 1 for
+  # blocks with a mix of True and False entries.
+  is_full_mask = jnp.all(partial_mask_blocks, axis=(-1, -2))
+  is_empty_mask = jnp.logical_not(jnp.any(partial_mask_blocks, axis=(-1, -2)))
+
+  block_mask = jnp.ones(block_mask_shape, dtype=np.int32)
+  block_mask = jnp.where(is_full_mask, 2, block_mask)
+  block_mask = jnp.where(is_empty_mask, 0, block_mask)
+
+  # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline.
+  mask_next = jnp.where(
+      jnp.logical_or(is_empty_mask, is_full_mask),
+      0,
+      jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape(
+          block_mask_shape
+      ),
+  )
+
+  # data_next stores the index of the next non-empty data block in the sequence.
+  # The indices of empty blocks are set to 0 to avoid copying extra data when
+  # pipelining.
+  if is_dkv:
+    data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None]
+  else:
+    data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :]
+  data_next = jnp.broadcast_to(data_next, block_mask_shape)
+  data_next = jnp.where(is_empty_mask, 0, data_next)
+
+  partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape)
+  if is_dkv:
+    partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2)
+
+  def _downcast(array: jax.Array, max_value: int) -> jax.Array:
+    if array.size == 0:
+      return array
+
+    if array.dtype != np.int32:
+      raise ValueError(f'Expected int32 input, but got {array.dtype}.')
+
+    if max_value <= np.iinfo(np.int8).max:
+      return array.astype(np.int8)
+    elif max_value <= np.iinfo(np.int16).max:
+      return array.astype(np.int16)
+    else:
+      return array.astype(np.int32)
+
+  if downcast_smem_data:
+    block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
+    data_next = _downcast(
+        data_next, q_blocks_count if is_dkv else kv_blocks_count
+    )
+    mask_next = _downcast(mask_next, math.prod(block_mask_shape))
+
+  return (
+      MaskInfo(
+          data_next=data_next,
+          mask_next=mask_next,
+          block_mask=block_mask,
+          partial_mask_blocks=partial_mask_blocks,
+          q_sequence=None,
+      ),
+      None,
+  )
+
+
 # When used in a transformer network with multiple layers, the SplashAttention
 # kernel is created several times with the same mask. Cache MaskInfo to avoid
 # blowing up compile times. Ideally the size of the cache should be determined
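
To make the 0/1/2 block-mask encoding concrete, here is a small self-contained sketch (ordinary jnp code, not part of this diff) that classifies the (2, 4) tiles of a single-head 8x8 causal mask the same way; the result matches the expected block mask in the new test further below:

    import jax.numpy as jnp

    causal = jnp.tril(jnp.ones((8, 8), dtype=jnp.bool_))

    # Tile into (q_blocks, kv_blocks, q_block, kv_block), mirroring the
    # reshape/swapaxes above for a single head.
    tiles = causal.reshape(4, 2, 2, 4).swapaxes(1, 2)

    is_full = jnp.all(tiles, axis=(-1, -2))
    is_empty = jnp.logical_not(jnp.any(tiles, axis=(-1, -2)))
    block_mask = jnp.where(is_full, 2, jnp.where(is_empty, 0, 1))
    # block_mask:
    # [[1 0]
    #  [1 0]
    #  [2 1]
    #  [2 1]]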
@@ -410,7 +560,7 @@ def _process_mask(
     mask_id_to_heads[mask_id].append(head)
     mask_id_to_head_shards[mask_id].add(head_shard)
 
-  # If we have at most one unique mask per each head shard, then we can brodcast
+  # If we have at most one unique mask per each head shard, then we can broadcast
   # the mask to all the heads in the shard. This is the common case.
   # If we have more than one mask in each head shard, then the optimization
   # cannot kick in and we use one mask for each head.
@@ -699,9 +849,7 @@ def _process_mask(
       current_block_mask,
       current_data_next,
       current_mask_next,
-  ) in zip(
-      block_mask_shards, data_next_shards, mask_next_shards
-  ):
+  ) in zip(block_mask_shards, data_next_shards, mask_next_shards):
     # For dKV shrinking happens along axis Q (the rows of MaskInfo), for
     # fwd and dQ shrinking happens along axis KV (the columns of MaskInfo).
     if is_dkv:
@@ -924,3 +1072,6 @@ def _slice_mask_info(
 
 process_mask = functools.partial(_process_mask, is_dkv=False)
 process_mask_dkv = functools.partial(_process_mask, is_dkv=True)
+
+process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False)
+process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True)
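
The new entry points mirror their static counterparts but run under tracing. A minimal usage sketch, assuming the import alias used by the tests (mask_info_lib) and keeping block_shape a static Python tuple under jit:

    import jax
    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.splash_attention import (
        splash_attention_mask_info as mask_info_lib,
    )

    # Dense boolean mask of shape [head_count, q_seq_len, kv_seq_len].
    mask = jnp.tril(jnp.ones((1, 8, 8), dtype=jnp.bool_))

    process_fn = jax.jit(
        mask_info_lib.process_dynamic_mask, static_argnames=["block_shape"]
    )
    mask_info, mask_function = process_fn(mask, block_shape=(2, 4))
    # mask_function is always None on the dynamic path, and the MaskInfo
    # fields are jax.Arrays rather than numpy arrays.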
@@ -292,6 +292,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]:
   return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0))
 
 
+def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array:
+  q_seq_len, kv_seq_len = mask.masks[0].shape
+  full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len))
+  dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0)
+
+  return dynamic_mask
+
+
 @jtu.with_config(jax_traceback_filtering="off")
 class PallasBaseTest(jtu.JaxTestCase):
   INTERPRET = False
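
For context, a hedged sketch of how this helper densifies a structured mask into the jax.Array form expected by the dynamic path (the CausalMask/MultiHeadMask constructor arguments are assumed, not taken from this diff):

    causal = mask_lib.CausalMask(shape=(8, 8))
    multi_head = mask_lib.MultiHeadMask((causal, causal))
    dynamic_mask = to_dynamic_mask(multi_head)  # bool jax.Array, shape (2, 8, 8)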
@@ -322,9 +330,10 @@ class SplashAttentionTest(PallasBaseTest):
   @parameterized.product(
       is_mqa=(False, True),
       is_segmented=(False, True),
+      is_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
-  def test_splash_attention(self, is_mqa, is_segmented, data):
+  def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
     k1, k2, k3 = random.split(key, 3)
@@ -353,6 +362,8 @@ class SplashAttentionTest(PallasBaseTest):
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
     mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if is_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
 
     if is_mqa:
@@ -384,10 +395,11 @@ class SplashAttentionTest(PallasBaseTest):
   @parameterized.product(
       is_mqa=(False, True),
       is_segmented=(False, True),
+      is_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
   def test_splash_attention_fwd(
-      self, is_mqa, is_segmented, data
+      self, is_mqa, is_segmented, is_dynamic_mask, data
   ):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
@@ -416,6 +428,8 @@ class SplashAttentionTest(PallasBaseTest):
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
     mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if is_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
     if is_mqa:
       attn_ref = splash.make_masked_mqa_reference(mask)
@@ -531,10 +545,17 @@ class SplashAttentionTest(PallasBaseTest):
       is_segmented=(False, True),
       downcast_smem_data=(False, True),
       use_fused_bwd_kernel=(False, True),
+      use_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
   def test_splash_attention_bwd(
-      self, is_mqa, is_segmented, downcast_smem_data, use_fused_bwd_kernel, data
+      self,
+      is_mqa,
+      is_segmented,
+      downcast_smem_data,
+      use_fused_bwd_kernel,
+      use_dynamic_mask,
+      data,
   ):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
@@ -563,6 +584,8 @@ class SplashAttentionTest(PallasBaseTest):
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
     mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if use_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(
         block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True,
                              use_fused_bwd_kernel=use_fused_bwd_kernel)
@@ -21,6 +21,7 @@ import jax
 from jax._src import test_util as jtu
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
+import jax.numpy as jnp
 import numpy as np
 
 jax.config.parse_flags_with_absl()
@@ -2091,6 +2092,87 @@ class SplashAttentionMaskInfoTest(jtu.JaxTestCase):
 
     self.assertIn("softmax", str(ctx.exception))
 
+  @parameterized.parameters((False,), (True,))
+  def test_dynamic_mask(self, is_dkv: bool):
+    head_count, q_seq_len, kv_seq_len = 1, 8, 8
+    block_shape = (2, 4)
+
+    mask = jnp.stack([_make_causal_mask((q_seq_len, kv_seq_len))] * head_count)
+
+    process_dynamic_mask_fn = jax.jit(
+        mask_info_lib.process_dynamic_mask,
+        static_argnames=["block_shape", "is_dkv"],
+    )
+    mask_info, _ = process_dynamic_mask_fn(
+        mask, block_shape=block_shape, is_dkv=is_dkv
+    )
+
+    _expected_block_mask = np.array(
+        [[
+            [1, 0],
+            [1, 0],
+            [2, 1],
+            [2, 1],
+        ]],
+        dtype=np.int8,
+    )
+
+    _expected_partial_mask_blocks = np.array(
+        [
+            [[1, 0, 0, 0], [1, 1, 0, 0]],
+            [[0, 0, 0, 0], [0, 0, 0, 0]],
+            [[1, 1, 1, 0], [1, 1, 1, 1]],
+            [[0, 0, 0, 0], [0, 0, 0, 0]],
+            [[1, 1, 1, 1], [1, 1, 1, 1]],
+            [[1, 0, 0, 0], [1, 1, 0, 0]],
+            [[1, 1, 1, 1], [1, 1, 1, 1]],
+            [[1, 1, 1, 0], [1, 1, 1, 1]],
+        ],
+        dtype=np.bool_,
+    )
+
+    _expected_mask_next = np.array(
+        [[
+            [0, 0],
+            [2, 0],
+            [0, 5],
+            [0, 7],
+        ]],
+        dtype=np.int8,
+    )
+
+    _expected_data_next = np.array(
+        [[
+            [0, 0],
+            [0, 0],
+            [0, 1],
+            [0, 1],
+        ]],
+        dtype=np.int8,
+    )
+
+    if is_dkv:
+      _expected_partial_mask_blocks = _expected_partial_mask_blocks.swapaxes(
+          -1, -2
+      )
+      _expected_data_next = np.array(
+          [[
+              [0, 0],
+              [1, 0],
+              [2, 2],
+              [3, 3],
+          ]],
+          dtype=np.int8,
+      )
+
+    self.assertArraysEqual(mask_info.block_mask, _expected_block_mask)
+    self.assertArraysEqual(
+        mask_info.partial_mask_blocks,
+        _expected_partial_mask_blocks,
+    )
+    self.assertArraysEqual(mask_info.mask_next, _expected_mask_next)
+    self.assertArraysEqual(mask_info.data_next, _expected_data_next)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())