[Pallas/TPU] Fix bug with LocalMask grid shrinking

LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, the kernel body must use `global_kv_index`: the kv index in the "global" space, i.e. before any shrinking of the iteration space.
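
For intuition, here is a toy sketch of what shrinking means. The numbers, block size, and array layout below are illustrative only, not MaskInfo's actual representation: with a local mask, each q block touches only a narrow band of kv blocks, so the kv dimension of the grid shrinks, and a data_next-style array is what recovers the global kv block index from the shrunk iteration index.

import numpy as np

# Hypothetical 4x4 block mask for a local mask: q block i attends to kv
# blocks i-1 and i (where they exist). At most 2 kv blocks per row are
# active, so the kv grid dimension can shrink from 4 steps to 2.
block_mask = np.array([
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
])

# Maps (q_block, shrunk kv step) -> global kv block index.
data_next = np.array([
    [0, 0],
    [0, 1],
    [1, 2],
    [2, 3],
])

bkv = 128                                   # made-up kv block size
q_block, step = 2, 0                        # `step` is the shrunk program_id
global_kv_index = data_next[q_block, step]  # -> 1, not `step` itself
print(global_kv_index * bkv)                # 128: offset of the real KV data
print(step * bkv)                           # 0: wrong offset from program_id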

PiperOrigin-RevId: 655901432
Author: jax authors, 2024-07-25 04:05:13 -07:00
Committer: jax authors
Parent: e14752c0ab
Commit: f15f9717c3
2 changed files with 78 additions and 4 deletions

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

@@ -725,7 +725,7 @@ def flash_attention_kernel(
     m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
     l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
 
-  _, _, should_run, should_not_mask = _next_nonzero(
+  global_kv_index, _, should_run, should_not_mask = _next_nonzero(
       h,
       i,
       j,
@@ -760,7 +760,11 @@ def flash_attention_kernel(
           kv_segment_ids_ref,
           attn_logits_soft_cap=attn_logits_soft_cap,
           k_slice=slice_k,
-          k_offset=j * bkv + kv_compute_index * bkv_compute,
+          # When the iteration space is shrunk (for local attention for example),
+          # the kv_index program_id does not correspond to the actual coordinates
+          # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+          # data_next array) when computing the mask.
+          k_offset=global_kv_index * bkv + kv_compute_index * bkv_compute,
           bq=bq,
           mask_function=mask_function,
       )
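
To see why `k_offset` must be derived from the global index, recall that `mask_function` is defined over absolute token coordinates, not over positions in the shrunk grid. A minimal sketch, assuming a causal mask function and made-up block sizes (this is not the kernel's actual code):

import numpy as np

bq, bkv_compute = 4, 4
mask_function = lambda q_ids, kv_ids: q_ids >= kv_ids  # assumed causal mask

q_ids = 8 + np.arange(bq)[:, None]             # q block covering tokens 8..11

# Correct: offset from global_kv_index * bkv + kv_compute_index * bkv_compute.
kv_ids_global = 8 + np.arange(bkv_compute)[None, :]
# Buggy: offset from j * bkv + ..., where j is the shrunk-grid program_id.
kv_ids_shrunk = 0 + np.arange(bkv_compute)[None, :]

print(mask_function(q_ids, kv_ids_global))     # lower-triangular, as intended
print(mask_function(q_ids, kv_ids_shrunk))     # all True: masking silently lost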
@@ -1282,7 +1286,7 @@ def _flash_attention_dq_kernel(
   def init():
     dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)
 
-  _, _, should_run, should_not_mask = _next_nonzero(
+  global_kv_index, _, should_run, should_not_mask = _next_nonzero(
       h, i, j, data_next_ref, block_mask_ref, mask_next_ref
   )
   @pl.when(should_run)
@@ -1308,7 +1312,11 @@ def _flash_attention_dq_kernel(
           kv_segment_ids_ref,
           attn_logits_soft_cap=attn_logits_soft_cap,
           k_slice=pl.ds(0, bkv),
-          k_offset=j * bkv,
+          # When the iteration space is shrunk (for local attention for example),
+          # the kv_index program_id does not correspond to the actual coordinates
+          # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+          # data_next array) when computing the mask.
+          k_offset=global_kv_index * bkv,
           bq=bq,
           mask_function=mask_function,
       )
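
For context, the code path fixed here is exercised whenever a mask shrinks the grid, for example local attention through the public API. A hedged sketch of the user-visible scenario (requires a TPU backend; the `LocalMask` parameters below are illustrative assumptions, not taken from this change):

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel as splash,
    splash_attention_mask as mask_lib,
)

seq_len, head_dim = 1024, 128
# A windowed/local mask: the kind of mask that triggers grid shrinking.
mask = mask_lib.MultiHeadMask(
    [mask_lib.LocalMask((seq_len, seq_len), window_size=(128, 128), offset=0)]
)
attn = splash.make_splash_mqa_single_device(mask)

q = jnp.ones((1, seq_len, head_dim), jnp.float32)  # (heads, seq, head_dim)
k = jnp.ones((seq_len, head_dim), jnp.float32)
v = jnp.ones((seq_len, head_dim), jnp.float32)
out = attn(q, k, v)  # pre-fix, the window could land at wrong kv offsets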

tests/pallas/tpu_splash_attention_kernel_test.py

@@ -28,6 +28,7 @@
 from jax import random
 from jax._src import test_util as jtu
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info import process_mask
 import jax.numpy as jnp
 import numpy as np
@@ -649,6 +650,71 @@ class SplashAttentionTest(PallasBaseTest):
     self._assert_allclose(dq, dq_ref, atol=2e-2, rtol=3e-2)
     self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2)
 
+  def test_grid_shrinking(self):
+    """Make sure that grid shrinking does not change the attention output."""
+
+    class IdentityMask(mask_lib._ComputableMask):
+      """Identity mask that is guaranteed to trigger grid shrinking."""
+
+      def __init__(
+          self,
+          shape: tuple[int, int],
+          shard_count: int = 1,
+      ):
+        def identity_mask_function(q_ids, kv_ids):
+          return q_ids == kv_ids
+
+        super().__init__(
+            shape=shape,
+            mask_function=identity_mask_function,
+            shard_count=shard_count,
+        )
+
+      def __eq__(self, other: object):
+        if not isinstance(other, type(self)):
+          return NotImplemented
+
+        return self.shape == other.shape and np.array_equal(
+            self.q_sequence, other.q_sequence
+        )
+
+      def __hash__(self):
+        return hash((
+            type(self),
+            self.shape,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        ))
+
+    # Use a sequence length greater than the default block size to trigger
+    # the grid shrinking logic.
+    seq_len = 256
+    head_dim = 128
+    key = random.key(42)
+    k1, k2, k3 = random.split(key, 3)
+    q = random.uniform(k1, (1, seq_len, head_dim), dtype=jnp.float32)
+    k = random.uniform(k2, (seq_len, head_dim), dtype=jnp.float32)
+    v = random.uniform(k3, (seq_len, head_dim), dtype=jnp.float32)
+
+    identity_mask = mask_lib.MultiHeadMask([IdentityMask((seq_len, seq_len))])
+
+    process_mask_path = "jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info.process_mask"
+
+    process_mask_shrink = lambda *args, **kwargs: process_mask(
+        *args, **kwargs, shrink_grid=True
+    )
+    process_mask_no_shrink = lambda *args, **kwargs: process_mask(
+        *args, **kwargs, shrink_grid=False
+    )
+
+    with unittest.mock.patch(process_mask_path, process_mask_shrink):
+      shrink_out = splash.make_splash_mqa_single_device(identity_mask)(q, k, v)
+
+    with unittest.mock.patch(process_mask_path, process_mask_no_shrink):
+      no_shrink_out = splash.make_splash_mqa_single_device(identity_mask)(
+          q, k, v
+      )
+
+    np.testing.assert_array_equal(shrink_out, no_shrink_out)
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
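
A note on why `IdentityMask` reliably shrinks the grid: with `seq_len = 256` and a block size of 128 (illustrative numbers; the real defaults come from splash's block-size configuration), the block-level mask is diagonal, so each q block has exactly one active kv block and the kv iteration space collapses from two steps to one. A quick numpy check of that claim:

import numpy as np

seq_len, block = 256, 128
ids = np.arange(seq_len)
dense = ids[:, None] == ids[None, :]          # the identity mask

n = seq_len // block
block_mask = dense.reshape(n, block, n, block).any(axis=(1, 3))
print(block_mask.astype(int))                 # [[1 0] [0 1]]: diagonal

# One active kv block per q row -> one shrunk step, whose global kv index
# a data_next-style array must supply.
print([np.flatnonzero(row).tolist() for row in block_mask])  # [[0], [1]]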