[Pallas] Be explicit about accumulation dtype in reference implementations

This commit is contained in:
Justin Fu 2025-01-27 22:09:29 +00:00
parent 6004a501ad
commit 7ace72fb3a
3 changed files with 16 additions and 7 deletions

View File

@ -394,7 +394,8 @@ def paged_attention_reference(
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
uncapped_logits = jnp.einsum(
"bkgd,bksd->bkgs", q_reshaped, k_transposed
"bkgd,bksd->bkgs", q_reshaped, k_transposed,
preferred_element_type=jnp.float32
).astype(jnp.float32)
if attn_logits_soft_cap is not None:
@ -410,7 +411,8 @@ def paged_attention_reference(
weights = jax.nn.softmax(logits, axis=-1)
o = jnp.einsum(
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32)
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32),
preferred_element_type=jnp.float32
).astype(q.dtype)
o = o.reshape(q.shape)

View File

@ -1652,7 +1652,10 @@ class OpsTest(PallasBaseTest):
x = random.normal(k1, lhs_shape, dtype=dtype)
y = random.normal(k2, rhs_shape, dtype=dtype)
out = dot(x, y)
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
# Pallas always accumulates in FP32, so we are explicit about
# preferred_element_type here.
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y,
preferred_element_type=jnp.float32).astype(dtype)
np.testing.assert_allclose(
out.astype(jnp.float32),
expected.astype(jnp.float32),

View File

@ -553,8 +553,10 @@ class PallasCallTest(PallasBaseTest):
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
interpret=self.INTERPRET), jnp.matmul(x, y)
out = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
interpret=self.INTERPRET)
expected = jnp.matmul(
x, y, preferred_element_type=jnp.float32).astype(dtype)
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
@parameterized.named_parameters(*[
@ -576,8 +578,10 @@ class PallasCallTest(PallasBaseTest):
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
interpret=self.INTERPRET), jnp.matmul(x, y)
out = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
interpret=self.INTERPRET)
expected = jnp.matmul(
x, y, preferred_element_type=jnp.float32).astype(dtype)
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
@parameterized.named_parameters(*(