Merge pull request #26715 from Rifur13:normalize

PiperOrigin-RevId: 730667196
jax authors 2025-02-24 17:54:06 -08:00
commit 41faf51a16
2 changed files with 41 additions and 12 deletions


@@ -143,7 +143,8 @@ def decode_attn_unbatched(
     grid: tuple[int, ...] | None,
     interpret: bool,
     debug: bool,
-    return_residuals: bool
+    return_residuals: bool,
+    normalize_output: bool
 ):
   num_heads, head_dim = q.shape
   k_seq_len, _ = k.shape
@@ -218,7 +219,9 @@ def decode_attn_unbatched(
   o = o * correction[:, :, None].astype(o.dtype)
   l_next = (l * correction).sum(axis=0)
   eps = jnp.finfo(l_next.dtype).eps
-  o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps)
+  o = o.sum(axis=0)
+  if normalize_output:
+    o /= (l_next[:, None].astype(o.dtype) + eps)
   if return_residuals:
     return o, (l_next, m_next)
   else:
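
For context, the epilogue above builds the softmax denominator l_next from the per-split partial sums; with normalize_output=False the division is skipped and the caller is expected to apply it later from the returned residuals. A minimal sketch of that deferred step follows (the helper name and shapes are illustrative assumptions, not part of this change):

import jax.numpy as jnp

def apply_softmax_normalization(o_unnormalized, l):
  # o_unnormalized: [num_heads, head_dim] unnormalized attention output.
  # l: [num_heads] softmax denominator (the l_next residual returned above).
  # Mirrors the normalize_output branch of the kernel epilogue.
  eps = jnp.finfo(l.dtype).eps
  return o_unnormalized / (l[:, None].astype(o_unnormalized.dtype) + eps)
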
@@ -237,7 +240,8 @@ def decode_attn_unbatched(
         "grid",
         "interpret",
         "debug",
-        "return_residuals"
+        "return_residuals",
+        "normalize_output"
     ],
 )
 def mqa(
@@ -255,7 +259,8 @@ def mqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
-    return_residuals: bool = False
+    return_residuals: bool = False,
+    normalize_output: bool = True,
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs = q.shape[0]
@@ -274,7 +279,8 @@ def mqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
-      return_residuals=return_residuals
+      return_residuals=return_residuals,
+      normalize_output=normalize_output,
   )
   return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)
@@ -291,7 +297,8 @@ def mqa(
         "grid",
         "interpret",
         "debug",
-        "return_residuals"
+        "return_residuals",
+        "normalize_output"
     ],
 )
 def gqa(
@@ -310,7 +317,12 @@
     interpret: bool = False,
     debug: bool = False,
     return_residuals: bool = False,
+    normalize_output: bool = True,
 ):
+  if not normalize_output and not return_residuals:
+    raise NotImplementedError(
+        "When normalize_output is False, attention residuals must be returned."
+    )
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   batch_size, q_heads, head_dim = q.shape
   k_seq_len, kv_heads = k.shape[1], k.shape[2]
@@ -344,6 +356,7 @@ def gqa(
       interpret=interpret,
       debug=debug,
       return_residuals=return_residuals,
+      normalize_output=normalize_output,
   )
   with_kv_heads = jax.vmap(inner)
   o, *res = jax.vmap(with_kv_heads)(
@@ -359,7 +372,10 @@
   return o


-@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
+@functools.partial(
+    jax.jit,
+    static_argnames=["sm_scale", "return_residuals", "normalize_output"],
+)
 def mqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, head_dim]
@@ -367,7 +383,8 @@ def mqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
-    return_residuals=False
+    return_residuals=False,
+    normalize_output=True,
 ):
   original_dtype = q.dtype
   q = q.astype(jnp.float32)
@@ -389,7 +406,8 @@ def mqa_reference(
   m = logits.max(axis=-1)
   s = jnp.exp(logits - m[..., None])
   l = s.sum(axis=-1)
-  s = s / l[..., None]
+  if normalize_output:
+    s = s / l[..., None]
   o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
   if return_residuals:
@@ -423,7 +441,10 @@ def mha_reference(
   return jnp.einsum("bns,bsnd->bnd", weights, v)


-@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
+@functools.partial(
+    jax.jit,
+    static_argnames=["sm_scale", "return_residuals", "normalize_output"],
+)
 def gqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, num_k_heads, head_dim]
@@ -431,7 +452,8 @@ def gqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
-    return_residuals=False
+    return_residuals=False,
+    normalize_output=True
 ):
   original_dtype = q.dtype
   q = q.astype(jnp.float32)
@@ -466,7 +488,8 @@ def gqa_reference(
   m = logits.max(axis=-1)
   s = jnp.exp(logits - m[..., None])
   l = s.sum(axis=-1)
-  s = s / l[..., None]
+  if normalize_output:
+    s = s / l[..., None]
   o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
   o = o.reshape(bs, num_q_heads, head_dim)
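
Taken together, the kernel and reference changes above let a caller request the unnormalized output plus residuals and normalize afterwards. The sketch below shows that round trip at the mqa call site; the array shapes, the module import path, and the reliance on the default start_idx/kv_seq_len are assumptions for illustration (pass interpret=True to the kernel to run without a GPU):

import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import decode_attention  # assumed import path

k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(k1, (1, 4, 64), dtype=jnp.float16)     # [bs, num_q_heads, head_dim]
k = random.normal(k2, (1, 1024, 64), dtype=jnp.float16)  # [bs, k_seq_len, head_dim]
v = random.normal(k3, (1, 1024, 64), dtype=jnp.float16)  # [bs, k_seq_len, head_dim]

# Unnormalized output plus residuals (l is the softmax denominator, m the row max).
o_unnorm, (l, m) = decode_attention.mqa(
    q, k, v, return_residuals=True, normalize_output=False)
# Dividing by l recovers the default normalize_output=True result.
o = decode_attention.mqa(q, k, v)
o_manual = o_unnorm / (l[..., None].astype(o_unnorm.dtype) + jnp.finfo(l.dtype).eps)
# o_manual and o should agree to within numerical tolerance.
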


@@ -105,6 +105,7 @@ class DecodeAttentionTest(PallasBaseTest):
       return_residuals,
   ):
     del kwargs
+    normalize_output = not return_residuals
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
@@ -118,6 +119,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output,
         interpret=self.INTERPRET,
     )
     o_ref, *res_ref = decode_attention.mqa_reference(
@@ -127,6 +129,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output
     )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
     if return_residuals:
@@ -181,6 +184,7 @@ class DecodeAttentionTest(PallasBaseTest):
       return_residuals,
   ):
     del kwargs
+    normalize_output = not return_residuals
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
@@ -199,6 +203,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output,
         interpret=self.INTERPRET,
     )
     o_ref, *res_ref = decode_attention.gqa_reference(
@@ -208,6 +213,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output
     )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
     if return_residuals:
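
The tests flip normalize_output off exactly when residuals are returned, which matches the main use of the flag: partial attention results computed over disjoint slices of the KV cache can be merged after the fact. The sketch below shows the standard log-sum-exp combine, mirroring the correction and l_next arithmetic in decode_attn_unbatched above; the helper and its shapes are hypothetical, not an API added by this change:

import jax.numpy as jnp

def combine_partials(o1, res1, o2, res2):
  # o1, o2: unnormalized outputs [..., head_dim] from two disjoint KV slices,
  # each obtained with normalize_output=False and return_residuals=True.
  # res1, res2: (l, m) residuals whose shape matches the leading dims of o1/o2.
  l1, m1 = res1
  l2, m2 = res2
  m = jnp.maximum(m1, m2)                    # combined row max
  c1, c2 = jnp.exp(m1 - m), jnp.exp(m2 - m)  # rescale each partial to the new max
  l = l1 * c1 + l2 * c2                      # combined softmax denominator
  o = o1 * c1[..., None].astype(o1.dtype) + o2 * c2[..., None].astype(o2.dtype)
  return o / (l[..., None].astype(o.dtype) + jnp.finfo(l.dtype).eps)
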