[pallas:mgpu] Change FA3 kernel because lax.div no longer accepts mixed operand types.

PiperOrigin-RevId: 726883573
Christos Perivolaropoulos 2025-02-14 05:10:07 -08:00 committed by jax authors
parent 3162cc4d0d
commit 49ad24152c
3 changed files with 30 additions and 10 deletions
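
The fix is a dtype cast: lax.div now requires both operands to share a dtype, so dividing the head index from lax.axis_index by a plain Python int errors out. A minimal standalone sketch of the pattern, outside the kernel (the scalar q_head below stands in for lax.axis_index("heads"); values are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

q_heads_per_kv_head = 2  # illustrative; in the kernel this is num_q_heads // num_kv_heads

def kv_head_for(q_head):
  # Old form, now rejected because the operand dtypes are mixed:
  #   lax.div(q_head, q_heads_per_kv_head)
  # New form, casting the constant to the index dtype before dividing:
  return lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))

print(jax.jit(kv_head_for)(jnp.int32(5)))  # truncated integer division -> 2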


@@ -71,6 +71,7 @@ def attention(q, k, v, config: TuningConfig):
def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
batch = lax.axis_index("batch")
q_head = lax.axis_index("heads")
smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
wg_idx = lax.axis_index("wg")
qo_smem2, k_smem, v_smem = smem_buffers
@@ -85,7 +86,6 @@ def attention(q, k, v, config: TuningConfig):
plgpu.set_max_registers(232, action="increase")
qo_smem = qo_smem2.at[wg_idx]
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")
plgpu.copy_gmem_to_smem(
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
@@ -175,7 +175,7 @@ def attention(q, k, v, config: TuningConfig):
@pl.when(wg_idx == 2)
def _memory_wg():
plgpu.set_max_registers(40, action="decrease")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
for i in range(max_concurrent_steps):
s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
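
For context only (not part of the diff): kv_head groups consecutive query heads onto a shared KV head, which is how the MQA/GQA cases in the test are handled. A quick worked sketch with assumed head counts matching the GQA test case, mirroring the scalar indices the kernel uses:

import jax.numpy as jnp
from jax import lax

num_q_heads, num_kv_heads = 6, 3                    # assumed GQA shape from the test
q_heads_per_kv_head = num_q_heads // num_kv_heads   # 2 query heads per KV head

for q in range(num_q_heads):
  q_head = jnp.int32(q)  # stand-in for lax.axis_index("heads")
  kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
  print(q, int(kv_head))  # maps q heads 0,1 -> kv 0; 2,3 -> 1; 4,5 -> 2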
@@ -268,11 +268,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
batch = lax.axis_index("batch")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
wg_idx = lax.axis_index("wg")
qo_smem2, q_barriers, schedule_barrier = scoped
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
def perform_schedule_barrier():
if config.use_schedule_barrier:


@@ -579,7 +579,10 @@ jax_multiplatform_test(
name = "mgpu_attention_test",
srcs = ["mgpu_attention_test.py"],
enable_backends = [],
enable_configs = ["gpu_h100_x32"],
enable_configs = [
"gpu_h100_x32",
"gpu_h100",
],
env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
deps = [
"//jax:pallas",


@@ -52,21 +52,38 @@ class FlashAttentionTestCase(jtu.JaxTestCase):
batch_size=(1, 4),
q_seq_len=(4096,),
kv_seq_len=(4096,),
num_q_and_kv_heads=((4, 1), # MQA
(6, 3), # GQA
(4, 4),), # MHA
num_q_and_kv_heads=(
(4, 1), # MQA
(6, 3), # GQA
(4, 4),
), # MHA
head_dim=(64, 128, 256),
attention_impl=(
attention_mgpu.attention,
attention_mgpu.attention_with_pipeline_emitter,
),
)
def test_flash_attention(
self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim
self,
batch_size,
q_seq_len,
kv_seq_len,
num_q_and_kv_heads,
head_dim,
attention_impl,
):
num_q_heads, num_kv_heads = num_q_and_kv_heads
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
out = attention_mgpu.attention(
q, k, v, attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)
out = attention_impl(
q,
k,
v,
attention_mgpu.TuningConfig(
block_q=64, block_kv=64, max_concurrent_steps=2
),
)
out_ref = attention_mgpu.attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)