mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #20917 from jakevdp:flash-fix
PiperOrigin-RevId: 627764834
This commit is contained in:
commit
66190d10e7
@ -910,8 +910,8 @@ def _flash_attention_dkv_kernel(
|
||||
|
||||
@pl.when(q_seq_index == q_seq_len // block_q_major - 1)
|
||||
def end_of_q_sequence():
|
||||
dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref)
|
||||
dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref)
|
||||
dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref.dtype)
|
||||
dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref.dtype)
|
||||
|
||||
|
||||
def _flash_attention_bwd_dkv(
|
||||
@ -1266,7 +1266,7 @@ def _flash_attention_dq_kernel(
|
||||
|
||||
@pl.when(kv_seq_index == kv_seq_len // block_k_major - 1)
|
||||
def end_of_kv_sequence():
|
||||
dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref)
|
||||
dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref.dtype)
|
||||
dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user