Allow head_dim <= 128 in Pallas:TPU flash attention implementation

PiperOrigin-RevId: 568275903
This commit is contained in:
Adam Paszke 2023-09-25 11:24:36 -07:00 committed by jax authors
parent c7f60fa6eb
commit ebf350b715
2 changed files with 9 additions and 15 deletions

View File

@@ -368,26 +368,26 @@ def _flash_attention_kernel_single_batch(
l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]
head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
if rem:
raise NotImplementedError(
f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE}"
)
if head_dim_repeats == 0:
l_broadcast = lambda l: l[:, :head_dim]
else:
raise NotImplementedError(
f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
)
l_scratch_ref[batch_idx] = l_next
m_scratch_ref[batch_idx] = m_next
l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
acc_scratch_ref[batch_idx] *= pltpu.repeat(
l_corr * l_next_inv_safe, head_dim_repeats, 1
)
acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
v = pl.load(
v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
)
o_curr = jax.lax.dot(
p.astype(v.dtype), v, preferred_element_type=jnp.float32
)
acc_scratch_ref[batch_idx] += o_curr * pltpu.repeat(
l_next_inv_safe, head_dim_repeats, 1
)
acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)
@pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
def store_output():

View File

@@ -883,17 +883,11 @@ class VectorLayoutInferer {
TPU_CHECK_OP(op.getType().getElementTypeBitWidth() == 32,
"Only 32-bit types supported");
auto offsets = op.getOffsets().getValue();
auto sizes = op.getSizes().getValue();
auto strides = op.getStrides().getValue();
for (auto offset_attr : offsets.take_back(2)) {
int off = offset_attr.cast<IntegerAttr>().getInt();
TPU_CHECK_OP(off == 0, "Only zero-offset slices supported.");
}
sizes = sizes.take_back(2);
TPU_CHECK_OP(
(sizes[0].cast<IntegerAttr>().getInt() % target_shape_[0] == 0) &&
(sizes[1].cast<IntegerAttr>().getInt() % target_shape_[1] == 0),
"Only lane and sublane aligned slices allowed.");
for (auto stride : strides) {
TPU_CHECK_OP(stride.cast<IntegerAttr>().getInt() == 1,
"Only trivial strides supported.");