Merge pull request #26715 from Rifur13:normalize
PiperOrigin-RevId: 730667196
This commit is contained in: commit 41faf51a16
@@ -143,7 +143,8 @@ def decode_attn_unbatched(
     grid: tuple[int, ...] | None,
     interpret: bool,
     debug: bool,
-    return_residuals: bool
+    return_residuals: bool,
+    normalize_output: bool
 ):
   num_heads, head_dim = q.shape
   k_seq_len, _ = k.shape

@@ -218,7 +219,9 @@ def decode_attn_unbatched(
   o = o * correction[:, :, None].astype(o.dtype)
   l_next = (l * correction).sum(axis=0)
   eps = jnp.finfo(l_next.dtype).eps
-  o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps)
+  o = o.sum(axis=0)
+  if normalize_output:
+    o /= (l_next[:, None].astype(o.dtype) + eps)
   if return_residuals:
     return o, (l_next, m_next)
   else:
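With normalize_output=False the kernel now returns the exp-weighted sum of values together with the softmax residuals (l_next, m_next), which lets a caller merge partial results computed over separate key/value chunks and normalize once at the end. A minimal sketch of that merge, assuming two hypothetical partials (o1, l1, m1) and (o2, l2, m2) obtained with return_residuals=True and normalize_output=False; this helper is illustrative, not part of the diff:

import jax.numpy as jnp

def combine_partials(o1, l1, m1, o2, l2, m2):
  # Standard log-sum-exp merge: rescale each partial to a common running max.
  m = jnp.maximum(m1, m2)
  c1 = jnp.exp(m1 - m)
  c2 = jnp.exp(m2 - m)
  l = c1 * l1 + c2 * l2                        # combined softmax denominator
  o = c1[..., None] * o1 + c2[..., None] * o2  # combined unnormalized output
  return o / l[..., None], l, m                # normalize once at the end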
@@ -237,7 +240,8 @@ def decode_attn_unbatched(
         "grid",
         "interpret",
         "debug",
-        "return_residuals"
+        "return_residuals",
+        "normalize_output"
     ],
 )
 def mqa(

@@ -255,7 +259,8 @@ def mqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
-    return_residuals: bool = False
+    return_residuals: bool = False,
+    normalize_output: bool = True,
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs = q.shape[0]

@@ -274,7 +279,8 @@ def mqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
-      return_residuals=return_residuals
+      return_residuals=return_residuals,
+      normalize_output=normalize_output,
   )
   return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)
 
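Sketch of the new calling convention for mqa (hypothetical q, k, v tensors with the shapes documented in the signature; assumes the (o, (l, m)) residual structure returned by decode_attn_unbatched above):

# With return_residuals=True and normalize_output=False, o is the unnormalized
# exp-weighted sum and (l, m) are the softmax denominator and running max.
o, (l, m) = mqa(q, k, v, return_residuals=True, normalize_output=False)
o_normalized = o / l[..., None].astype(o.dtype)  # matches normalize_output=True up to eps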
@@ -291,7 +297,8 @@ def mqa(
         "grid",
         "interpret",
         "debug",
-        "return_residuals"
+        "return_residuals",
+        "normalize_output"
     ],
 )
 def gqa(

@@ -310,7 +317,12 @@ def gqa(
     interpret: bool = False,
     debug: bool = False,
     return_residuals: bool = False,
+    normalize_output: bool = True,
 ):
+  if not normalize_output and not return_residuals:
+    raise NotImplementedError(
+        "When normalize_output is False, attention residuals must be returned."
+    )
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   batch_size, q_heads, head_dim = q.shape
   k_seq_len, kv_heads = k.shape[1], k.shape[2]
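The guard above makes the unsupported combination explicit: with normalize_output=False and no residuals, the caller would have no softmax denominator left to finish the normalization, so gqa rejects it up front. For illustration (hypothetical tensors; return structure as sketched earlier):

gqa(q, k, v, return_residuals=False, normalize_output=False)  # raises NotImplementedError
o, *res = gqa(q, k, v, return_residuals=True, normalize_output=False)  # supported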
@@ -344,6 +356,7 @@ def gqa(
       interpret=interpret,
       debug=debug,
       return_residuals=return_residuals,
+      normalize_output=normalize_output,
   )
   with_kv_heads = jax.vmap(inner)
   o, *res = jax.vmap(with_kv_heads)(

@@ -359,7 +372,10 @@ def gqa(
   return o
 
 
-@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
+@functools.partial(
+    jax.jit,
+    static_argnames=["sm_scale", "return_residuals", "normalize_output"],
+)
 def mqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, head_dim]

@@ -367,7 +383,8 @@ def mqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
-    return_residuals=False
+    return_residuals=False,
+    normalize_output=True,
 ):
   original_dtype = q.dtype
   q = q.astype(jnp.float32)

@@ -389,7 +406,8 @@ def mqa_reference(
   m = logits.max(axis=-1)
   s = jnp.exp(logits - m[..., None])
   l = s.sum(axis=-1)
-  s = s / l[..., None]
+  if normalize_output:
+    s = s / l[..., None]
   o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
 
   if return_residuals:
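In the reference, skipping the division by l when normalize_output=False yields an output that differs from the normalized one by exactly that factor. A small self-contained check of the identity, using toy numbers independent of the kernel:

import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 0.5]])                 # [heads, seq]
v = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # [seq, dim]
m = logits.max(axis=-1)
s = jnp.exp(logits - m[..., None])
l = s.sum(axis=-1)
o_unnormalized = s @ v                 # normalize_output=False path
o_normalized = (s / l[..., None]) @ v  # normalize_output=True path
assert jnp.allclose(o_unnormalized / l[..., None], o_normalized)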
@@ -423,7 +441,10 @@ def mha_reference(
   return jnp.einsum("bns,bsnd->bnd", weights, v)
 
 
-@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
+@functools.partial(
+    jax.jit,
+    static_argnames=["sm_scale", "return_residuals", "normalize_output"],
+)
 def gqa_reference(
     q, # [bs, num_q_heads, head_dim]
     k, # [bs, k_seq_len, num_k_heads, head_dim]

@@ -431,7 +452,8 @@ def gqa_reference(
     start_idx=None, # [bs]
     kv_seq_len=None, # [bs]
     sm_scale=None,
-    return_residuals=False
+    return_residuals=False,
+    normalize_output=True
 ):
   original_dtype = q.dtype
   q = q.astype(jnp.float32)

@@ -466,7 +488,8 @@ def gqa_reference(
   m = logits.max(axis=-1)
   s = jnp.exp(logits - m[..., None])
   l = s.sum(axis=-1)
-  s = s / l[..., None]
+  if normalize_output:
+    s = s / l[..., None]
   o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
   o = o.reshape(bs, num_q_heads, head_dim)
 

@@ -105,6 +105,7 @@ class DecodeAttentionTest(PallasBaseTest):
       return_residuals,
   ):
     del kwargs
+    normalize_output = not return_residuals
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
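In the tests, normalize_output is simply set to the opposite of return_residuals, so the unnormalized path is exercised exactly when the residuals needed to check it are available; both flags are then forwarded to the kernel and to the reference in the hunks below.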
@@ -118,6 +119,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output,
         interpret=self.INTERPRET,
     )
     o_ref, *res_ref = decode_attention.mqa_reference(

@@ -127,6 +129,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output
     )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
     if return_residuals:

@@ -181,6 +184,7 @@ class DecodeAttentionTest(PallasBaseTest):
       return_residuals,
   ):
     del kwargs
+    normalize_output = not return_residuals
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(

@@ -199,6 +203,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output,
         interpret=self.INTERPRET,
     )
     o_ref, *res_ref = decode_attention.gqa_reference(

@@ -208,6 +213,7 @@ class DecodeAttentionTest(PallasBaseTest):
         start_idx=start_idx,
         kv_seq_len=kv_seq_len,
         return_residuals=return_residuals,
+        normalize_output=normalize_output
     )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
     if return_residuals: