mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26134 from justinjfu:pallas_accum_bugfix
PiperOrigin-RevId: 720374819
This commit is contained in:
commit
24987a90dc
@ -394,7 +394,8 @@ def paged_attention_reference(
|
||||
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
|
||||
|
||||
uncapped_logits = jnp.einsum(
|
||||
"bkgd,bksd->bkgs", q_reshaped, k_transposed
|
||||
"bkgd,bksd->bkgs", q_reshaped, k_transposed,
|
||||
preferred_element_type=jnp.float32
|
||||
).astype(jnp.float32)
|
||||
|
||||
if attn_logits_soft_cap is not None:
|
||||
@ -410,7 +411,8 @@ def paged_attention_reference(
|
||||
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
o = jnp.einsum(
|
||||
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32)
|
||||
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32),
|
||||
preferred_element_type=jnp.float32
|
||||
).astype(q.dtype)
|
||||
o = o.reshape(q.shape)
|
||||
|
||||
|
@ -1652,7 +1652,10 @@ class OpsTest(PallasBaseTest):
|
||||
x = random.normal(k1, lhs_shape, dtype=dtype)
|
||||
y = random.normal(k2, rhs_shape, dtype=dtype)
|
||||
out = dot(x, y)
|
||||
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
|
||||
# Pallas always accumulates in FP32, so we are explicit about
|
||||
# preferred_element_type here.
|
||||
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y,
|
||||
preferred_element_type=jnp.float32).astype(dtype)
|
||||
np.testing.assert_allclose(
|
||||
out.astype(jnp.float32),
|
||||
expected.astype(jnp.float32),
|
||||
|
@ -553,8 +553,10 @@ class PallasCallTest(PallasBaseTest):
|
||||
k1, k2 = random.split(random.key(0))
|
||||
x = random.normal(k1, (m, k), dtype=dtype)
|
||||
y = random.normal(k2, (k, n), dtype=dtype)
|
||||
out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
|
||||
interpret=self.INTERPRET), jnp.matmul(x, y)
|
||||
out = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
|
||||
interpret=self.INTERPRET)
|
||||
expected = jnp.matmul(
|
||||
x, y, preferred_element_type=jnp.float32).astype(dtype)
|
||||
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
|
||||
|
||||
@parameterized.named_parameters(*[
|
||||
@ -576,8 +578,10 @@ class PallasCallTest(PallasBaseTest):
|
||||
k1, k2 = random.split(random.key(0))
|
||||
x = random.normal(k1, (m, k), dtype=dtype)
|
||||
y = random.normal(k2, (k, n), dtype=dtype)
|
||||
out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
|
||||
interpret=self.INTERPRET), jnp.matmul(x, y)
|
||||
out = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
|
||||
interpret=self.INTERPRET)
|
||||
expected = jnp.matmul(
|
||||
x, y, preferred_element_type=jnp.float32).astype(dtype)
|
||||
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
|
||||
|
||||
@parameterized.named_parameters(*(
|
||||
|
Loading…
x
Reference in New Issue
Block a user