Save residuals in the decode attention pallas kernel

Gleb Pobudzey 2024-12-02 06:51:49 +00:00
parent bd66f5280b
commit a4e742d2fe
2 changed files with 123 additions and 25 deletions
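For context, the residuals saved by this change are the per-query softmax statistics used by flash-attention style kernels: m is the row maximum of the attention logits and l is the corresponding normalizer sum(exp(logits - m)). A minimal pure-JAX sketch of how an attention output relates to its (l, m) residuals (function name and shapes are illustrative only, not part of this commit):

import jax.numpy as jnp

def attention_with_residuals(q, k, v, sm_scale=1.0):
  # q: [num_heads, head_dim], k and v: [k_seq_len, head_dim]
  logits = (q @ k.T) * sm_scale     # [num_heads, k_seq_len]
  m = logits.max(axis=-1)           # row max (residual)
  s = jnp.exp(logits - m[:, None])  # shifted exponentials
  l = s.sum(axis=-1)                # softmax normalizer (residual)
  o = (s / l[:, None]) @ v          # normalized attention output
  return o, (l, m)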


@ -143,6 +143,7 @@ def decode_attn_unbatched(
grid: tuple[int, ...] | None,
interpret: bool,
debug: bool,
return_residuals: bool,
):
num_heads, head_dim = q.shape
k_seq_len, _ = k.shape
@ -215,7 +216,10 @@ def decode_attn_unbatched(
l_next = (l * correction).sum(axis=0)
eps = jnp.finfo(l_next.dtype).eps
o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps)
return o
if return_residuals:
return o, (l_next, m_next)
else:
return o
@functools.partial(
@ -230,6 +234,7 @@ def decode_attn_unbatched(
"grid",
"interpret",
"debug",
"return_residuals"
],
)
def mqa(
@ -247,6 +252,7 @@ def mqa(
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
return_residuals: bool = False,
):
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
bs = q.shape[0]
@ -265,6 +271,7 @@ def mqa(
grid=grid,
interpret=interpret,
debug=debug,
return_residuals=return_residuals,
)
return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)
@ -281,6 +288,7 @@ def mqa(
"grid",
"interpret",
"debug",
"return_residuals"
],
)
def gqa(
@ -298,6 +306,7 @@ def gqa(
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
return_residuals: bool = False,
):
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
batch_size, q_heads, head_dim = q.shape
@ -331,14 +340,23 @@ def gqa(
grid=grid,
interpret=interpret,
debug=debug,
return_residuals=return_residuals,
)
with_kv_heads = jax.vmap(inner)
o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed,
start_idx, kv_seq_len)
return o.reshape(batch_size, q_heads, head_dim)
# inner returns a bare output unless return_residuals is set, so only
# unpack the residuals in that case.
out = jax.vmap(with_kv_heads)(
q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len
)
o = out[0] if return_residuals else out
o = o.reshape(batch_size, q_heads, head_dim)
if return_residuals:
l, m = out[1]
l = l.reshape(batch_size, q_heads)
m = m.reshape(batch_size, q_heads)
return o, (l, m)
else:
return o
@functools.partial(jax.jit, static_argnames=["sm_scale"])
@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
def mqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, head_dim]
@ -346,10 +364,16 @@ def mqa_reference(
start_idx=None, # [bs]
kv_seq_len=None, # [bs]
sm_scale=None,
return_residuals=False,
):
original_dtype = q.dtype
q = q.astype(jnp.float32)
k = k.astype(jnp.float32)
bs = q.shape[0]
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
if sm_scale is not None and sm_scale != 1.0:
logits = logits * sm_scale
if start_idx is not None or kv_seq_len is not None:
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@ -358,8 +382,17 @@ def mqa_reference(
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
mask = mask[:, None, :]
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsd->bnd", weights, v)
m = logits.max(axis=-1)
s = jnp.exp(logits - m[..., None])
l = s.sum(axis=-1)
s = s / l[..., None]
o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
if return_residuals:
return o, (l, m)
else:
return o
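Aside (illustrative only, not part of this change): the (l, m) pair above carries the same information as the log-sum-exp of the masked, scaled logits,

lse = m + jnp.log(l)  # equals jax.nn.logsumexp(logits, axis=-1)

which is the single-array form some attention implementations store instead.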
@functools.partial(jax.jit, static_argnames=["sm_scale"])
@ -387,7 +420,7 @@ def mha_reference(
return jnp.einsum("bns,bsnd->bnd", weights, v)
@functools.partial(jax.jit, static_argnames=["sm_scale"])
@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
def gqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
@ -395,7 +428,11 @@ def gqa_reference(
start_idx=None, # [bs]
kv_seq_len=None, # [bs]
sm_scale=None,
return_residuals=False,
):
original_dtype = q.dtype
q = q.astype(jnp.float32)
k = k.astype(jnp.float32)
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
bs, num_q_heads, head_dim = q.shape
num_kv_heads = k.shape[2]
@ -412,6 +449,8 @@ def gqa_reference(
logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
jnp.float32
)
if sm_scale is not None and sm_scale != 1.0:
logits = logits * sm_scale
if start_idx is not None or kv_seq_len is not None:
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@ -420,7 +459,17 @@ def gqa_reference(
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
mask = mask[:, None, None, :]
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
m = logits.max(axis=-1)
s = jnp.exp(logits - m[..., None])
l = s.sum(axis=-1)
s = s / l[..., None]
o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
o = o.reshape(bs, num_q_heads, head_dim)
return o
if return_residuals:
l = l.reshape(bs, num_q_heads)
m = m.reshape(bs, num_q_heads)
return o, (l, m)
else:
return o
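One illustrative use of the saved residuals (not part of this commit): because o is returned already normalized by l, attention outputs computed over disjoint key/value chunks can be merged exactly from their (l, m) pairs. A hedged sketch with hypothetical names, assuming o: [bs, num_heads, head_dim] and l, m: [bs, num_heads] as in the comments above:

import jax.numpy as jnp

def merge_attention(o1, res1, o2, res2):
  # Combine two normalized attention outputs over disjoint kv chunks
  # using their softmax residuals (l = normalizer, m = row max).
  (l1, m1), (l2, m2) = res1, res2
  m = jnp.maximum(m1, m2)            # new global row max
  w1 = l1 * jnp.exp(m1 - m)          # rescaled normalizers
  w2 = l2 * jnp.exp(m2 - m)
  l = w1 + w2
  o = (o1 * w1[..., None] + o2 * w2[..., None]) / l[..., None]
  return o, (l, m)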


@ -21,6 +21,7 @@ import jax
from jax import random
from jax._src import config
from jax._src import test_util as jtu
if sys.platform != "win32":
from jax.experimental.pallas.ops.gpu import decode_attention
else:
@ -48,8 +49,9 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("On CPU, the test works only in interpret mode")
if jax.config.x64_enabled:
self.skipTest("The test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
if jtu.test_device_matches(
["cuda"]
) and not jtu.is_cuda_compute_capability_at_least("8.0"):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32":
self.skipTest("Only works on non-Windows platforms")
@ -62,8 +64,10 @@ class DecodeAttentionTest(PallasBaseTest):
@parameterized.named_parameters(*[
(
(f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
f"{start_idx=}_{kv_seq_len=}"),
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
f"{start_idx=}_{kv_seq_len=}_{return_residuals=}"
),
batch_size,
seq_len,
num_heads,
@ -71,6 +75,7 @@ class DecodeAttentionTest(PallasBaseTest):
kwargs,
start_idx,
kv_seq_len,
return_residuals,
)
for (
batch_size,
@ -85,6 +90,7 @@ class DecodeAttentionTest(PallasBaseTest):
]
for start_idx in [None, 123]
for kv_seq_len in [None, 250]
for return_residuals in [False, True]
])
@jax.numpy_dtype_promotion("standard")
def test_mqa(
@ -96,6 +102,7 @@ class DecodeAttentionTest(PallasBaseTest):
kwargs,
start_idx,
kv_seq_len,
return_residuals,
):
del kwargs
@ -104,16 +111,36 @@ class DecodeAttentionTest(PallasBaseTest):
k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)
o = decode_attention.mqa(q, k, v, start_idx=start_idx,
kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx,
kv_seq_len=kv_seq_len)
out = decode_attention.mqa(
q,
k,
v,
start_idx=start_idx,
kv_seq_len=kv_seq_len,
return_residuals=return_residuals,
interpret=self.INTERPRET,
)
out_ref = decode_attention.mqa_reference(
q,
k,
v,
start_idx=start_idx,
kv_seq_len=kv_seq_len,
return_residuals=return_residuals,
)
# mqa returns a bare output unless return_residuals is set, so only
# unpack the residuals in that case.
o, o_ref = (out[0], out_ref[0]) if return_residuals else (out, out_ref)
np.testing.assert_allclose(o, o_ref, atol=0.05)
if return_residuals:
l, m = out[1]
l_ref, m_ref = out_ref[1]
np.testing.assert_allclose(l, l_ref, atol=0.05)
np.testing.assert_allclose(m, m_ref, atol=0.05)
@parameterized.named_parameters(*[
(
(f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"),
(
f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
f"_{kwargs=}_{start_idx=}_{kv_seq_len=}_{return_residuals=}"
),
batch_size,
seq_len,
num_q_heads,
@ -122,6 +149,7 @@ class DecodeAttentionTest(PallasBaseTest):
kwargs,
start_idx,
kv_seq_len,
return_residuals,
)
for (
batch_size,
@ -137,6 +165,7 @@ class DecodeAttentionTest(PallasBaseTest):
]
for start_idx in [None, 123]
for kv_seq_len in [None, 250]
for return_residuals in [False, True]
])
@jax.numpy_dtype_promotion("standard")
def test_gqa(
@ -149,6 +178,7 @@ class DecodeAttentionTest(PallasBaseTest):
kwargs,
start_idx,
kv_seq_len,
return_residuals,
):
del kwargs
@ -162,11 +192,30 @@ class DecodeAttentionTest(PallasBaseTest):
v = random.normal(
k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
)
o = decode_attention.gqa(q, k, v, start_idx=start_idx,
kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx,
kv_seq_len=kv_seq_len)
out = decode_attention.gqa(
q,
k,
v,
start_idx=start_idx,
kv_seq_len=kv_seq_len,
return_residuals=return_residuals,
interpret=self.INTERPRET,
)
out_ref = decode_attention.gqa_reference(
q,
k,
v,
start_idx=start_idx,
kv_seq_len=kv_seq_len,
return_residuals=return_residuals,
)
# gqa returns a bare output unless return_residuals is set, so only
# unpack the residuals in that case.
o, o_ref = (out[0], out_ref[0]) if return_residuals else (out, out_ref)
np.testing.assert_allclose(o, o_ref, atol=0.05)
if return_residuals:
l, m = out[1]
l_ref, m_ref = out_ref[1]
np.testing.assert_allclose(l, l_ref, atol=0.05)
np.testing.assert_allclose(m, m_ref, atol=0.05)
class DecodeAttentionInterpretTest(DecodeAttentionTest):
INTERPRET = True
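Finally, a minimal usage sketch of the new flag (shapes taken from the comments above; variable names are illustrative):

# Residuals are only returned when explicitly requested.
o = decode_attention.mqa(q, k, v, kv_seq_len=kv_seq_len)
o, (l, m) = decode_attention.mqa(
    q, k, v, kv_seq_len=kv_seq_len, return_residuals=True
)
# o: [bs, num_q_heads, head_dim]; l and m: [bs, num_q_heads]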