mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow head_dim <= 128 in Pallas:TPU flash attention implementation
PiperOrigin-RevId: 568275903
This commit is contained in:
parent
c7f60fa6eb
commit
ebf350b715
@ -368,26 +368,26 @@ def _flash_attention_kernel_single_batch(
|
||||
l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]
|
||||
|
||||
head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
|
||||
l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
|
||||
if rem:
|
||||
raise NotImplementedError(
|
||||
f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE}"
|
||||
)
|
||||
if head_dim_repeats == 0:
|
||||
l_broadcast = lambda l: l[:, :head_dim]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
|
||||
)
|
||||
l_scratch_ref[batch_idx] = l_next
|
||||
m_scratch_ref[batch_idx] = m_next
|
||||
|
||||
l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
|
||||
acc_scratch_ref[batch_idx] *= pltpu.repeat(
|
||||
l_corr * l_next_inv_safe, head_dim_repeats, 1
|
||||
)
|
||||
acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
|
||||
v = pl.load(
|
||||
v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
|
||||
)
|
||||
o_curr = jax.lax.dot(
|
||||
p.astype(v.dtype), v, preferred_element_type=jnp.float32
|
||||
)
|
||||
acc_scratch_ref[batch_idx] += o_curr * pltpu.repeat(
|
||||
l_next_inv_safe, head_dim_repeats, 1
|
||||
)
|
||||
acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)
|
||||
|
||||
@pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
|
||||
def store_output():
|
||||
|
@ -883,17 +883,11 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(op.getType().getElementTypeBitWidth() == 32,
|
||||
"Only 32-bit types supported");
|
||||
auto offsets = op.getOffsets().getValue();
|
||||
auto sizes = op.getSizes().getValue();
|
||||
auto strides = op.getStrides().getValue();
|
||||
for (auto offset_attr : offsets.take_back(2)) {
|
||||
int off = offset_attr.cast<IntegerAttr>().getInt();
|
||||
TPU_CHECK_OP(off == 0, "Only zero-offset slices supported.");
|
||||
}
|
||||
sizes = sizes.take_back(2);
|
||||
TPU_CHECK_OP(
|
||||
(sizes[0].cast<IntegerAttr>().getInt() % target_shape_[0] == 0) &&
|
||||
(sizes[1].cast<IntegerAttr>().getInt() % target_shape_[1] == 0),
|
||||
"Only lane and sublane aligned slices allowed.");
|
||||
for (auto stride : strides) {
|
||||
TPU_CHECK_OP(stride.cast<IntegerAttr>().getInt() == 1,
|
||||
"Only trivial strides supported.");
|
||||
|
Loading…
x
Reference in New Issue
Block a user