Merge pull request #25642 from Rifur13:numerical_stability

PiperOrigin-RevId: 716714783
This commit is contained in:
jax authors 2025-01-17 10:19:36 -08:00
commit 4d20052f7a
3 changed files with 13 additions and 9 deletions

View File

@@ -598,7 +598,9 @@ def mha_reference(
):
q_seq_len = q.shape[1]
kv_seq_len = k.shape[1]
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
logits = jnp.einsum(
'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
)
mask = None
if segment_ids is not None:
mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
@@ -608,5 +610,7 @@ def mha_reference(
causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
weights = jax.nn.softmax(logits * sm_scale)
return jnp.einsum(
'bhqk,bkhc->bqhc', weights, v, preferred_element_type=jnp.float32
)

View File

@@ -482,7 +482,7 @@ jax_multiplatform_test(
"gpu_a100_x32",
"gpu_h100_x32",
],
shard_count = 6,
shard_count = 15,
deps = [
"//jax:pallas",
"//jax:pallas_gpu",

View File

@@ -224,8 +224,8 @@ class FusedAttentionTest(PallasBaseTest):
@jtu.sample_product(
batch_size=(1, 2),
seq_len=(128, 384),
num_heads=(1, 2, 4),
head_dim=(32,),
num_heads=(1, 2),
head_dim=(32, 64, 128,),
causal=(True, False),
use_segment_ids=(True, False),
)
@@ -266,9 +266,9 @@ class FusedAttentionTest(PallasBaseTest):
dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v)
dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v)
# TODO(sharadmv): Fix test.
np.testing.assert_allclose(dq, dq_ref, atol=0.14)
np.testing.assert_allclose(dk, dk_ref, atol=0.14)
np.testing.assert_allclose(dv, dv_ref, atol=0.05)
self.assertAllClose(dq, dq_ref, atol=5e-2)
self.assertAllClose(dk, dk_ref, atol=5e-2)
self.assertAllClose(dv, dv_ref, atol=1e-3)
class FusedAttentionInterpretTest(FusedAttentionTest):