Merge pull request #25642 from Rifur13:numerical_stability
PiperOrigin-RevId: 716714783
Commit: 4d20052f7a
@@ -598,7 +598,9 @@ def mha_reference(
 ):
   q_seq_len = q.shape[1]
   kv_seq_len = k.shape[1]
-  logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
+  logits = jnp.einsum(
+      'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
+  )
   mask = None
   if segment_ids is not None:
     mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
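A minimal sketch of the numerical point behind the hunk above, with hypothetical shapes and values that are not taken from the change: .astype(jnp.float32) upcasts a result that has already been rounded to the inputs' dtype, whereas preferred_element_type=jnp.float32 asks for the contraction itself to be produced in float32.

import jax
import jax.numpy as jnp

# Hypothetical (batch, seq_len, num_heads, head_dim) inputs in bfloat16.
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (1, 8, 2, 64)).astype(jnp.bfloat16)
k = jax.random.normal(kk, (1, 8, 2, 64)).astype(jnp.bfloat16)

# Old form: the einsum output is bfloat16, so the logits are rounded to
# bfloat16 precision before being upcast.
logits_old = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)

# New form: the einsum output is requested directly in float32.
logits_new = jnp.einsum(
    'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
)

print(float(jnp.max(jnp.abs(logits_new - logits_old))))  # rounding gap, typically nonzero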
@@ -608,5 +610,7 @@ def mha_reference(
     causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
     mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
   logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
-  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-  return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
+  weights = jax.nn.softmax(logits * sm_scale)
+  return jnp.einsum(
+      'bhqk,bkhc->bqhc', weights, v, preferred_element_type=jnp.float32
+  )
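The hunk above extends the same idea to the rest of the reference: the softmax weights are no longer cast back to q.dtype, and the weights-times-values contraction also accumulates in float32. A rough end-to-end sketch of the two variants for comparison; attn_lowp and attn_f32 are made-up helper names, and this toy version omits the masking and segment-id handling of the real mha_reference.

import jax
import jax.numpy as jnp

def attn_lowp(q, k, v, sm_scale=1.0):
  # Old style: intermediates are repeatedly rounded to the input dtype.
  logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
  return jnp.einsum('bhqk,bkhc->bqhc', weights, v)

def attn_f32(q, k, v, sm_scale=1.0):
  # New style: both contractions produce float32 and the softmax
  # weights stay in float32.
  logits = jnp.einsum(
      'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
  )
  weights = jax.nn.softmax(logits * sm_scale)
  return jnp.einsum(
      'bhqk,bkhc->bqhc', weights, v, preferred_element_type=jnp.float32
  )

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 128, 2, 64)  # hypothetical (batch, seq_len, num_heads, head_dim)
q = jax.random.normal(kq, shape).astype(jnp.bfloat16)
k = jax.random.normal(kk, shape).astype(jnp.bfloat16)
v = jax.random.normal(kv, shape).astype(jnp.bfloat16)

out_lowp = attn_lowp(q, k, v).astype(jnp.float32)
out_f32 = attn_f32(q, k, v)
print(float(jnp.max(jnp.abs(out_f32 - out_lowp))))

The printed gap gives a feel for how much rounding the old reference introduced on bfloat16 inputs, which is the slack the looser test tolerances previously had to absorb.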
@@ -482,7 +482,7 @@ jax_multiplatform_test(
         "gpu_a100_x32",
         "gpu_h100_x32",
     ],
-    shard_count = 6,
+    shard_count = 15,
     deps = [
         "//jax:pallas",
         "//jax:pallas_gpu",
@@ -224,8 +224,8 @@ class FusedAttentionTest(PallasBaseTest):
   @jtu.sample_product(
       batch_size=(1, 2),
       seq_len=(128, 384),
-      num_heads=(1, 2, 4),
-      head_dim=(32,),
+      num_heads=(1, 2),
+      head_dim=(32, 64, 128,),
       causal=(True, False),
       use_segment_ids=(True, False),
   )
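As a rough sanity check on the size of the sweep (counting the full cartesian product; the jtu.sample_product harness may instantiate fewer cases in practice), the head_dim expansion roughly doubles the number of combinations, which lines up with the shard_count bump in the BUILD hunk above.

# Back-of-the-envelope count of the parameter combinations in the sweep.
from itertools import product

old = dict(batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2, 4),
           head_dim=(32,), causal=(True, False), use_segment_ids=(True, False))
new = dict(batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2),
           head_dim=(32, 64, 128), causal=(True, False),
           use_segment_ids=(True, False))

print(len(list(product(*old.values()))))  # 48 combinations before the change
print(len(list(product(*new.values()))))  # 96 combinations after the change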
@@ -266,9 +266,9 @@ class FusedAttentionTest(PallasBaseTest):
     dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v)
     dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v)
     # TODO(sharadmv): Fix test.
-    np.testing.assert_allclose(dq, dq_ref, atol=0.14)
-    np.testing.assert_allclose(dk, dk_ref, atol=0.14)
-    np.testing.assert_allclose(dv, dv_ref, atol=0.05)
+    self.assertAllClose(dq, dq_ref, atol=5e-2)
+    self.assertAllClose(dk, dk_ref, atol=5e-2)
+    self.assertAllClose(dv, dv_ref, atol=1e-3)


 class FusedAttentionInterpretTest(FusedAttentionTest):