Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[pallas:mgpu] Change the FA3 kernel because lax.div no longer accepts mixed operand types.
PiperOrigin-RevId: 726883573
This commit is contained in:
parent 3162cc4d0d
commit 49ad24152c
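The pattern applied throughout this diff: lax.div now requires both operands to share a dtype, so the Python-int divisor q_heads_per_kv_head is materialized as an array with the dividend's dtype instead of being passed in directly. A minimal standalone sketch of that pattern follows (hypothetical example values, not the kernel code itself):

import jax
import jax.numpy as jnp
from jax import lax

q_heads_per_kv_head = 2  # plain Python int, as in the kernel's closure

@jax.jit
def kv_head_index(q_head):
  # q_head stands in for lax.axis_index("heads"): a traced int32 scalar.
  # Casting the divisor to q_head.dtype keeps both lax.div operands int32.
  return lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))

print(kv_head_index(jnp.int32(5)))  # integer division toward zero -> 2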
@@ -71,6 +71,7 @@ def attention(q, k, v, config: TuningConfig):
   def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
+    q_head = lax.axis_index("heads")
     smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
@@ -85,7 +86,6 @@ def attention(q, k, v, config: TuningConfig):
       plgpu.set_max_registers(232, action="increase")
       qo_smem = qo_smem2.at[wg_idx]
       q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
-      q_head = lax.axis_index("heads")

       plgpu.copy_gmem_to_smem(
           q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
@@ -175,7 +175,7 @@ def attention(q, k, v, config: TuningConfig):
     @pl.when(wg_idx == 2)
     def _memory_wg():
       plgpu.set_max_registers(40, action="decrease")
-      kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
+      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
       for i in range(max_concurrent_steps):
         s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
@@ -268,11 +268,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
   def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
-    kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
     wg_idx = lax.axis_index("wg")
     qo_smem2, q_barriers, schedule_barrier = scoped
     q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
     q_head = lax.axis_index("heads")
+    kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))

     def perform_schedule_barrier():
       if config.use_schedule_barrier:
@@ -579,7 +579,10 @@ jax_multiplatform_test(
     name = "mgpu_attention_test",
     srcs = ["mgpu_attention_test.py"],
     enable_backends = [],
-    enable_configs = ["gpu_h100_x32"],
+    enable_configs = [
+        "gpu_h100_x32",
+        "gpu_h100",
+    ],
     env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     deps = [
         "//jax:pallas",
@@ -52,21 +52,38 @@ class FlashAttentionTestCase(jtu.JaxTestCase):
       batch_size=(1, 4),
       q_seq_len=(4096,),
       kv_seq_len=(4096,),
-      num_q_and_kv_heads=((4, 1), # MQA
-                          (6, 3), # GQA
-                          (4, 4),), # MHA
+      num_q_and_kv_heads=(
+          (4, 1),  # MQA
+          (6, 3),  # GQA
+          (4, 4),
+      ),  # MHA
       head_dim=(64, 128, 256),
+      attention_impl=(
+          attention_mgpu.attention,
+          attention_mgpu.attention_with_pipeline_emitter,
+      ),
   )
   def test_flash_attention(
-      self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim
+      self,
+      batch_size,
+      q_seq_len,
+      kv_seq_len,
+      num_q_and_kv_heads,
+      head_dim,
+      attention_impl,
   ):
     num_q_heads, num_kv_heads = num_q_and_kv_heads
     k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
     q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
     k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
     v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
-    out = attention_mgpu.attention(
-        q, k, v, attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)
+    out = attention_impl(
+        q,
+        k,
+        v,
+        attention_mgpu.TuningConfig(
+            block_q=64, block_kv=64, max_concurrent_steps=2
+        ),
     )
     out_ref = attention_mgpu.attention_reference(q, k, v)
     np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
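A hedged sketch of what the new attention_impl parameter does (assumed import path and setup mirroring the test above; running it requires a Hopper-class GPU): parameterized.product now expands every case over both entry points, so the same inputs and TuningConfig exercise attention as well as attention_with_pipeline_emitter.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.gpu import attention_mgpu

# One MQA-style case from the grid above: batch 1, seq 4096, (4, 1) heads, head_dim 64.
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
q = jax.random.normal(k1, (1, 4096, 4, 64), jnp.float16)
k = jax.random.normal(k2, (1, 4096, 1, 64), jnp.float16)
v = jax.random.normal(k3, (1, 4096, 1, 64), jnp.float16)

for attention_impl in (
    attention_mgpu.attention,
    attention_mgpu.attention_with_pipeline_emitter,
):
  out = attention_impl(
      q, k, v,
      attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2),
  )
  np.testing.assert_allclose(
      out, attention_mgpu.attention_reference(q, k, v), atol=2e-3, rtol=1e-3
  )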