Merge pull request #20917 from jakevdp:flash-fix

PiperOrigin-RevId: 627764834
This commit is contained in:
jax authors 2024-04-24 10:04:20 -07:00
commit 66190d10e7

View File

@@ -910,8 +910,8 @@ def _flash_attention_dkv_kernel(
   @pl.when(q_seq_index == q_seq_len // block_q_major - 1)
   def end_of_q_sequence():
-    dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref)
-    dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref)
+    dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref.dtype)
+    dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref.dtype)
 def _flash_attention_bwd_dkv(
@@ -1266,7 +1266,7 @@ def _flash_attention_dq_kernel(
   @pl.when(kv_seq_index == kv_seq_len // block_k_major - 1)
   def end_of_kv_sequence():
-    dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref)
+    dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref.dtype)
     dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)