Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Save residuals in the decode attention pallas kernel

parent bd66f5280b
commit a4e742d2fe
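This change threads a return_residuals flag through the Pallas decode-attention kernels (mqa, gqa) in jax.experimental.pallas.ops.gpu.decode_attention, through their jnp reference implementations, and through the corresponding tests. When the flag is set, each call returns the attention output together with the softmax residuals (l, m): per-head softmax denominators and running maxima. A minimal usage sketch of the new API follows; the shapes, PRNG setup, and variable names are illustrative (not taken from the diff), and running the non-reference path on CPU assumes interpret mode is acceptable.

import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import decode_attention

# Illustrative decode-time shapes: one query token per batch element.
batch_size, num_heads, head_dim, k_seq_len = 2, 8, 64, 256
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
k = random.normal(k2, (batch_size, k_seq_len, head_dim), dtype=jnp.float16)
v = random.normal(k3, (batch_size, k_seq_len, head_dim), dtype=jnp.float16)

# With return_residuals=True the call returns (o, (l, m)):
#   o: [batch, num_heads, head_dim] attention output
#   l: [batch, num_heads] softmax denominators
#   m: [batch, num_heads] row maxima of the scaled logits
o, (l, m) = decode_attention.mqa(q, k, v, return_residuals=True, interpret=True)
o_ref, (l_ref, m_ref) = decode_attention.mqa_reference(q, k, v, return_residuals=True)

The diff below touches two files: the kernel module (jax.experimental.pallas.ops.gpu.decode_attention) and its test.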
@@ -143,6 +143,7 @@ def decode_attn_unbatched(
     grid: tuple[int, ...] | None,
     interpret: bool,
     debug: bool,
+    return_residuals: bool
 ):
   num_heads, head_dim = q.shape
   k_seq_len, _ = k.shape
@@ -215,7 +216,10 @@ def decode_attn_unbatched(
   l_next = (l * correction).sum(axis=0)
   eps = jnp.finfo(l_next.dtype).eps
   o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps)
-  return o
+  if return_residuals:
+    return o, (l_next, m_next)
+  else:
+    return o


 @functools.partial(
@@ -230,6 +234,7 @@ def decode_attn_unbatched(
         "grid",
         "interpret",
         "debug",
+        "return_residuals"
     ],
 )
 def mqa(
@@ -247,6 +252,7 @@ def mqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs = q.shape[0]
@@ -265,6 +271,7 @@ def mqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
+      return_residuals=return_residuals
   )
   return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)

@@ -281,6 +288,7 @@ def mqa(
         "grid",
         "interpret",
         "debug",
+        "return_residuals"
     ],
 )
 def gqa(
@@ -298,6 +306,7 @@ def gqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False,
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   batch_size, q_heads, head_dim = q.shape
@@ -331,14 +340,23 @@ def gqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
+      return_residuals=return_residuals,
   )
   with_kv_heads = jax.vmap(inner)
-  o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed,
-                              start_idx, kv_seq_len)
-  return o.reshape(batch_size, q_heads, head_dim)
+  o, *res = jax.vmap(with_kv_heads)(
+      q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len
+  )
+  o = o.reshape(batch_size, q_heads, head_dim)
+  if return_residuals:
+    l, m = res[0]
+    l = l.reshape(batch_size, q_heads)
+    m = m.reshape(batch_size, q_heads)
+    return o, (l, m)
+  else:
+    return o


-@functools.partial(jax.jit, static_argnames=["sm_scale"])
+@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
 def mqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, head_dim]
@@ -346,10 +364,16 @@ def mqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
+    return_residuals=False
 ):
+  original_dtype = q.dtype
+  q = q.astype(jnp.float32)
+  k = k.astype(jnp.float32)
   bs = q.shape[0]
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
+  if sm_scale is not None and sm_scale != 1.0:
+    logits = logits * sm_scale
   if start_idx is not None or kv_seq_len is not None:
     start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
     kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@@ -358,8 +382,17 @@ def mqa_reference(
             & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
     mask = mask[:, None, :]
     logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
-  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-  return jnp.einsum("bns,bsd->bnd", weights, v)
+
+  m = logits.max(axis=-1)
+  s = jnp.exp(logits - m[..., None])
+  l = s.sum(axis=-1)
+  s = s / l[..., None]
+  o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
+
+  if return_residuals:
+    return o, (l, m)
+  else:
+    return o


 @functools.partial(jax.jit, static_argnames=["sm_scale"])
@@ -387,7 +420,7 @@ def mha_reference(
   return jnp.einsum("bns,bsnd->bnd", weights, v)


-@functools.partial(jax.jit, static_argnames=["sm_scale"])
+@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
 def gqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, num_k_heads, head_dim]
@@ -395,7 +428,11 @@ def gqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
+    return_residuals=False
 ):
+  original_dtype = q.dtype
+  q = q.astype(jnp.float32)
+  k = k.astype(jnp.float32)
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs, num_q_heads, head_dim = q.shape
   num_kv_heads = k.shape[2]
@@ -412,6 +449,8 @@ def gqa_reference(
   logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
       jnp.float32
   )
+  if sm_scale is not None and sm_scale != 1.0:
+    logits = logits * sm_scale
   if start_idx is not None or kv_seq_len is not None:
     start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
     kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@@ -420,7 +459,17 @@ def gqa_reference(
             & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
     mask = mask[:, None, None, :]
     logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
-  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-  o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
+
+  m = logits.max(axis=-1)
+  s = jnp.exp(logits - m[..., None])
+  l = s.sum(axis=-1)
+  s = s / l[..., None]
+  o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
   o = o.reshape(bs, num_q_heads, head_dim)
-  return o
+
+  if return_residuals:
+    l = l.reshape(bs, num_q_heads)
+    m = m.reshape(bs, num_q_heads)
+    return o, (l, m)
+  else:
+    return o
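The (l, m) residuals saved above are the usual flash-attention softmax statistics: m is the per-row maximum of the scaled, masked logits and l is the sum of exponentials after subtracting it, so the attention weights and the log-sum-exp can be recovered from them. A small illustration follows (my own sketch, not part of the diff; the function name is hypothetical):

import jax.numpy as jnp

def softmax_with_residuals(logits):
  # Mirrors the reference implementations above.
  m = logits.max(axis=-1)             # row maximum
  s = jnp.exp(logits - m[..., None])  # shifted exponentials
  l = s.sum(axis=-1)                  # softmax denominator
  weights = s / l[..., None]          # == jax.nn.softmax(logits, axis=-1)
  lse = m + jnp.log(l)                # log-sum-exp, recoverable from (l, m)
  return weights, (l, m), lse

Keeping these statistics lets a caller rescale or combine partial attention outputs (for example across KV splits) without recomputing the softmax.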
@@ -21,6 +21,7 @@ import jax
 from jax import random
+from jax._src import config
 from jax._src import test_util as jtu

 if sys.platform != "win32":
   from jax.experimental.pallas.ops.gpu import decode_attention
 else:
@@ -48,8 +49,9 @@ class PallasBaseTest(jtu.JaxTestCase):
       self.skipTest("On CPU, the test works only in interpret mode")
     if jax.config.x64_enabled:
       self.skipTest("The test works only in 32-bit")
-    if (jtu.test_device_matches(["cuda"]) and
-        not jtu.is_cuda_compute_capability_at_least("8.0")):
+    if jtu.test_device_matches(
+        ["cuda"]
+    ) and not jtu.is_cuda_compute_capability_at_least("8.0"):
       self.skipTest("Only works on GPU with capability >= sm80")
     if sys.platform == "win32":
       self.skipTest("Only works on non-Windows platforms")
@@ -62,8 +64,10 @@ class DecodeAttentionTest(PallasBaseTest):

   @parameterized.named_parameters(*[
       (
-          (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
-           f"{start_idx=}_{kv_seq_len=}"),
+          (
+              f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
+              f"{start_idx=}_{kv_seq_len=}_{return_residuals=}"
+          ),
           batch_size,
           seq_len,
           num_heads,
@@ -71,6 +75,7 @@ class DecodeAttentionTest(PallasBaseTest):
           kwargs,
           start_idx,
           kv_seq_len,
+          return_residuals,
       )
       for (
           batch_size,
@@ -85,6 +90,7 @@ class DecodeAttentionTest(PallasBaseTest):
       ]
       for start_idx in [None, 123]
       for kv_seq_len in [None, 250]
+      for return_residuals in [False, True]
   ])
   @jax.numpy_dtype_promotion("standard")
   def test_mqa(
@@ -96,6 +102,7 @@ class DecodeAttentionTest(PallasBaseTest):
       kwargs,
       start_idx,
       kv_seq_len,
+      return_residuals,
   ):
     del kwargs

@@ -104,16 +111,36 @@ class DecodeAttentionTest(PallasBaseTest):
     k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
     v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)

-    o = decode_attention.mqa(q, k, v, start_idx=start_idx,
-                             kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
-    o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx,
-                                           kv_seq_len=kv_seq_len)
+    o, *res = decode_attention.mqa(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+        interpret=self.INTERPRET,
+    )
+    o_ref, *res_ref = decode_attention.mqa_reference(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+    )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
+    if return_residuals:
+      l, m = res[0]
+      l_ref, m_ref = res_ref[0]
+      np.testing.assert_allclose(l, l_ref, atol=0.05)
+      np.testing.assert_allclose(m, m_ref, atol=0.05)

   @parameterized.named_parameters(*[
       (
-          (f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
-           f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"),
+          (
+              f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
+              f"_{kwargs=}_{start_idx=}_{kv_seq_len=}_{return_residuals=}"
+          ),
           batch_size,
           seq_len,
           num_q_heads,
@@ -122,6 +149,7 @@ class DecodeAttentionTest(PallasBaseTest):
           kwargs,
           start_idx,
           kv_seq_len,
+          return_residuals,
       )
       for (
           batch_size,
@@ -137,6 +165,7 @@ class DecodeAttentionTest(PallasBaseTest):
       ]
       for start_idx in [None, 123]
       for kv_seq_len in [None, 250]
+      for return_residuals in [False, True]
   ])
   @jax.numpy_dtype_promotion("standard")
   def test_gqa(
@@ -149,6 +178,7 @@ class DecodeAttentionTest(PallasBaseTest):
       kwargs,
       start_idx,
       kv_seq_len,
+      return_residuals,
   ):
     del kwargs

@@ -162,11 +192,30 @@ class DecodeAttentionTest(PallasBaseTest):
     v = random.normal(
         k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
     )
-    o = decode_attention.gqa(q, k, v, start_idx=start_idx,
-                             kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
-    o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx,
-                                           kv_seq_len=kv_seq_len)
+    o, *res = decode_attention.gqa(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+        interpret=self.INTERPRET,
+    )
+    o_ref, *res_ref = decode_attention.gqa_reference(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+    )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
+    if return_residuals:
+      l, m = res[0]
+      l_ref, m_ref = res_ref[0]
+      np.testing.assert_allclose(l, l_ref, atol=0.05)
+      np.testing.assert_allclose(m, m_ref, atol=0.05)
+

 class DecodeAttentionInterpretTest(DecodeAttentionTest):
   INTERPRET = True