Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Pallas/TPU] Fix bug with LocalMask grid shrinking
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space. As a consequence, the kernel body must use the `global_kv_index`: the kv index in the "global" space, i.e. before any shrinking of the iteration space.

PiperOrigin-RevId: 655901432
parent e14752c0ab
commit f15f9717c3
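The fix below replaces the shrunk kv program_id with the global index returned by `_next_nonzero`. Conceptually, grid shrinking iterates only over the non-empty KV blocks, so the kv program_id no longer equals a block's true coordinate, and a lookup (the data_next array in the MaskInfo) recovers it. A minimal NumPy sketch of that idea; the names `non_empty_kv_blocks` and `shrunk_step` are illustrative, not the kernel's actual data structures:

import numpy as np

# Illustrative only: suppose a local mask leaves just these KV blocks
# non-empty for some query block (global kv block coordinates).
non_empty_kv_blocks = np.array([3, 4, 5])

# The shrunk grid runs over steps 0, 1, 2, ..., so the kv program_id no
# longer equals the block coordinate; offsets into K/V (and into the mask)
# must use the looked-up global index, e.g. k_offset = global_kv_index * bkv.
for shrunk_step in range(len(non_empty_kv_blocks)):
  global_kv_index = int(non_empty_kv_blocks[shrunk_step])
  print(shrunk_step, "->", global_kv_index)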
@@ -725,7 +725,7 @@ def flash_attention_kernel(
     m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
     l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)

-  _, _, should_run, should_not_mask = _next_nonzero(
+  global_kv_index, _, should_run, should_not_mask = _next_nonzero(
       h,
       i,
       j,
@@ -760,7 +760,11 @@ def flash_attention_kernel(
         kv_segment_ids_ref,
         attn_logits_soft_cap=attn_logits_soft_cap,
         k_slice=slice_k,
-        k_offset=j * bkv + kv_compute_index * bkv_compute,
+        # When the iteration space is shrunk (for local attention for example),
+        # the kv_index program_id does not correspond to the actual coordinates
+        # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+        # data_next array) when computing the mask.
+        k_offset=global_kv_index * bkv + kv_compute_index * bkv_compute,
         bq=bq,
         mask_function=mask_function,
     )
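To make the effect of the offset concrete, here is a small arithmetic sketch; the block sizes and indices are assumed for illustration and are not taken from the kernel's defaults:

# Assumed numbers for illustration only.
bkv, bkv_compute = 512, 128
j = 1                   # shrunk kv program_id
global_kv_index = 3     # true kv block coordinate, recovered via data_next
kv_compute_index = 2

old_offset = j * bkv + kv_compute_index * bkv_compute                 # 768
new_offset = global_kv_index * bkv + kv_compute_index * bkv_compute   # 1792
# With the shrunk program_id the mask would be evaluated 1024 elements too
# early in the KV sequence, masking the wrong region.
print(old_offset, new_offset)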
@@ -1282,7 +1286,7 @@ def _flash_attention_dq_kernel(
   def init():
     dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)

-  _, _, should_run, should_not_mask = _next_nonzero(
+  global_kv_index, _, should_run, should_not_mask = _next_nonzero(
       h, i, j, data_next_ref, block_mask_ref, mask_next_ref
   )
   @pl.when(should_run)
@@ -1308,7 +1312,11 @@ def _flash_attention_dq_kernel(
         kv_segment_ids_ref,
         attn_logits_soft_cap=attn_logits_soft_cap,
         k_slice=pl.ds(0, bkv),
-        k_offset=j * bkv,
+        # When the iteration space is shrunk (for local attention for example),
+        # the kv_index program_id does not correspond to the actual coordinates
+        # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+        # data_next array) when computing the mask.
+        k_offset=global_kv_index * bkv,
         bq=bq,
         mask_function=mask_function,
     )
@@ -28,6 +28,7 @@ from jax import random
 from jax._src import test_util as jtu
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info import process_mask
 import jax.numpy as jnp
 import numpy as np

@@ -649,6 +650,71 @@ class SplashAttentionTest(PallasBaseTest):
     self._assert_allclose(dq, dq_ref, atol=2e-2, rtol=3e-2)
     self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2)

+  def test_grid_shrinking(self):
+    """Make sure that grid shrinking does not change the attention output."""
+
+    class IdentityMask(mask_lib._ComputableMask):
+      """Identity mask that is guaranteed to trigger grid shrinking."""
+
+      def __init__(
+          self,
+          shape: tuple[int, int],
+          shard_count: int = 1,
+      ):
+        def identity_mask_function(q_ids, kv_ids):
+          return q_ids == kv_ids
+
+        super().__init__(
+            shape=shape,
+            mask_function=identity_mask_function,
+            shard_count=shard_count,
+        )
+
+      def __eq__(self, other: object):
+        if not isinstance(other, type(self)):
+          return NotImplemented
+
+        return self.shape == other.shape and np.array_equal(
+            self.q_sequence, other.q_sequence
+        )
+
+      def __hash__(self):
+        return hash((
+            type(self),
+            self.shape,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        ))
+
+    # Use a sequence length greater than the default block size to trigger
+    # the grid shrinking logic.
+    seq_len = 256
+    head_dim = 128
+    key = random.key(42)
+    k1, k2, k3 = random.split(key, 3)
+    q = random.uniform(k1, (1, seq_len, head_dim), dtype=jnp.float32)
+    k = random.uniform(k2, (seq_len, head_dim), dtype=jnp.float32)
+    v = random.uniform(k3, (seq_len, head_dim), dtype=jnp.float32)
+
+    identity_mask = mask_lib.MultiHeadMask([IdentityMask((seq_len, seq_len))])
+
+    process_mask_path = "jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info.process_mask"
+    process_mask_shrink = lambda *args, **kwargs: process_mask(
+        *args, **kwargs, shrink_grid=True
+    )
+    process_mask_no_shrink = lambda *args, **kwargs: process_mask(
+        *args, **kwargs, shrink_grid=False
+    )
+
+    with unittest.mock.patch(process_mask_path, process_mask_shrink):
+      shrink_out = splash.make_splash_mqa_single_device(identity_mask)(q, k, v)
+
+    with unittest.mock.patch(process_mask_path, process_mask_no_shrink):
+      no_shrink_out = splash.make_splash_mqa_single_device(identity_mask)(
+          q, k, v
+      )
+
+    np.testing.assert_array_equal(shrink_out, no_shrink_out)
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())