# rocm_jax/jax/experimental/pallas/ops/gpu/decode_attention.py
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing decode attention."""
from __future__ import annotations
import functools
from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
def attn_forward_kernel(
q_ref, # [num_heads, head_dim]
k_ref, # [k_seq_len, head_dim]
v_ref, # [k_seq_len, head_dim]
o_ref: Any, # [num_heads, head_dim]
*residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,]
sm_scale: float,
block_k: int,
):
block_h, head_dim = q_ref.shape
k_seq_len, _ = k_ref.shape
  # o is the buffer where we accumulate the output in SRAM.
  # m_i and l_i (see the FlashAttention-2 paper) are updated during the k,v
  # loop; the update rule is sketched in the comment block after this kernel.
  m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf")
  l_i = jnp.zeros(block_h, dtype=jnp.float32)
  o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)

  # Load q: it will stay in L1 throughout. Indices form a matrix because we
  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
  # The q BlockSpec already selects this program's [block_h, head_dim] tile of
  # heads, so all loads and stores below index the block locally.
  curr_q_slice = pl.dslice(0, block_h)
  q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))

def _dot(a, b):
# if a.shape[0] == 1:
# # Use matrix vector product
# return (a.T * b).sum(axis=0, keepdims=True)
return pl.dot(a, b)

  # Loop over blocks of kv to process the kv split assigned to this program.
  # The grid iterates over blocks of query heads (axis 0) and kv splits (axis 1).
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
curr_k_slice = pl.dslice(start_k * block_k, block_k)
k = pl.load(k_ref, (curr_k_slice, slice(None)))
qk = _dot(q, k.T) # [block_h, block_k]
if sm_scale != 1.0:
qk *= sm_scale # [block_h, block_k]
m_curr = qk.max(axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
correction = jnp.exp(m_prev - m_next)
l_prev_corr = correction * l_prev
s_curr = jnp.exp(
qk - m_next[:, None]
) # Use m_next instead of m_curr to avoid a correction on l_curr
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
v = pl.load(v_ref, (curr_k_slice, slice(None)))
o_curr = _dot(s_curr.astype(v.dtype), v)
    # FlashAttention-2 style: keep the accumulator unscaled and instead rescale
    # the previous partial output by `correction`.
o_next = correction[:, None] * o_prev + o_curr
return o_next, m_next, l_next
upper_bound = pl.cdiv(k_seq_len, block_k)
# o is left unscaled; it will be scaled in the final reduction step
o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
if residual_refs:
l_ref, m_ref = residual_refs
pl.store(l_ref, (curr_q_slice,), l_i)
pl.store(m_ref, (curr_q_slice,), m_i)
# Write output to dram.
o = o.astype(o_ref.dtype)
pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
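

# The kv loop in attn_forward_kernel maintains a numerically stable "online"
# softmax. With running maximum m, running denominator l, and an unscaled
# accumulator o, each new block of scores qk updates the state as follows
# (a sketch of the FlashAttention-2 recurrence the kernel implements):
#
#   m_next = max(m_prev, max_j qk_j)
#   c      = exp(m_prev - m_next)          # correction applied to the old state
#   s_j    = exp(qk_j - m_next)
#   l_next = c * l_prev + sum_j s_j
#   o_next = c * o_prev + s @ v_block      # still unscaled by l
#
# The division by l (and the max over per-split partial results) happens in
# attn_unbatched once all kv splits have been processed.

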
def attn_unbatched(
q, # [num_heads, head_dim]
k, # [k_seq_len, head_dim]
v, # [k_seq_len, head_dim]
sm_scale: float,
block_h: int,
block_k: int,
k_splits: int,
num_warps: int | None,
num_stages: int,
grid: tuple[int, ...] | None,
interpret: bool,
debug: bool,
):
num_heads, head_dim = q.shape
k_seq_len, _ = k.shape
# Pad num query heads to 16 if needed, and slice output at the end.
original_num_heads = None
if num_heads < 16:
q = jnp.pad(q, ((0, 16 - num_heads), (0, 0)))
original_num_heads = num_heads
num_heads = q.shape[0]
block_h = min(block_h, num_heads)
head_splits = pl.cdiv(num_heads, block_h)
grid_ = grid
if grid_ is None:
grid_ = (head_splits, k_splits)
assert (
k_seq_len % k_splits == 0
), f"{k_seq_len=} must be divisible by {k_splits=}"
k = k.reshape(k_splits, k_seq_len // k_splits, head_dim)
v = v.reshape(k_splits, k_seq_len // k_splits, head_dim)
k_seq_len = k_seq_len // k_splits
assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16"
block_k = min(block_k, k_seq_len)
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4
kernel = functools.partial(
attn_forward_kernel,
sm_scale=sm_scale,
block_k=block_k,
)
o, l, m = pl.pallas_call(
kernel,
grid=grid_,
in_specs=[
pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)),
pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
],
out_specs=[
pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m
],
compiler_params=plgpu.TritonCompilerParams(
num_warps=num_warps_, num_stages=num_stages
),
out_shape=[
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # l
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # m
],
debug=debug,
interpret=interpret,
name="mha_forward",
)(q, k, v)
  # Final combine across the k splits (same rescaling identity as the kernel
  # loop; a standalone sketch follows this function).
m_next = m.max(axis=0)
correction = jnp.exp(m - m_next[None])
o = o * correction[:, :, None]
l_next = (l * correction).sum(axis=0)
o = o.sum(axis=0) / l_next[:, None]
if original_num_heads is not None:
o = o[:original_num_heads, :]
return o
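

# Illustrative sketch (added; not used by the kernels above): the cross-split
# combine performed at the end of attn_unbatched, as a standalone jnp function.
# Each kv split contributes an unscaled partial output o, row sums l, and row
# maxima m; merging them reuses the rescaling identity from the kernel loop.
def _combine_splits_sketch(o, l, m):
  # o: [k_splits, num_heads, head_dim]; l, m: [k_splits, num_heads]
  m_next = m.max(axis=0)
  correction = jnp.exp(m - m_next[None])
  l_next = (l * correction).sum(axis=0)
  return (o * correction[:, :, None]).sum(axis=0) / l_next[:, None]

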
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
],
)
def mqa(
q, # [batch_size, num_heads, head_dim]
k, # [batch_size, k_seq_len, head_dim]
v, # [batch_size, k_seq_len, head_dim]
sm_scale: float = 1.0,
block_h: int = 16,
block_k: int = 256,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
):
inner = functools.partial(
attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
)
return jax.vmap(inner)(q, k, v)
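

# Example usage of mqa (illustrative shapes, not from the original source):
#
#   q = jax.random.normal(jax.random.key(0), (2, 16, 64))    # [batch, heads, head_dim]
#   k = jax.random.normal(jax.random.key(1), (2, 1024, 64))  # [batch, k_seq_len, head_dim]
#   v = jax.random.normal(jax.random.key(2), (2, 1024, 64))
#   out = mqa(q, k, v, sm_scale=64 ** -0.5)                   # -> [2, 16, 64]
#
# With the default k_splits=16, k_seq_len must be divisible by 16 and each
# split must be at least 16 (so k_seq_len >= 256); head_dim must be >= 16.

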
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
],
)
def gqa(
q, # [batch_size, num_q_heads, head_dim]
k, # [batch_size, k_seq_len, num_kv_heads, head_dim]
v, # [batch_size, k_seq_len, num_kv_heads, head_dim]
sm_scale: float = 1.0,
block_h: int = 16,
block_k: int = 256,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
):
batch_size, q_heads, head_dim = q.shape
kv_heads = k.shape[2]
assert kv_heads == v.shape[2]
assert q_heads % kv_heads == 0
q_heads_per_kv_head = q_heads // kv_heads
q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
inner = functools.partial(
attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
)
with_kv_heads = jax.vmap(inner)
o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed)
return o.reshape(batch_size, q_heads, head_dim)
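

# Example usage of gqa (illustrative shapes, not from the original source):
#
#   q = jax.random.normal(jax.random.key(0), (2, 32, 64))       # [batch, q_heads, head_dim]
#   k = jax.random.normal(jax.random.key(1), (2, 1024, 4, 64))  # [batch, k_seq_len, kv_heads, head_dim]
#   v = jax.random.normal(jax.random.key(2), (2, 1024, 4, 64))
#   out = gqa(q, k, v, sm_scale=64 ** -0.5)                      # -> [2, 32, 64]
#
# q_heads must be a multiple of kv_heads; each group of q_heads // kv_heads
# query heads attends to one kv head (groups smaller than 16 are padded).

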
@functools.partial(jax.jit, static_argnames=["sm_scale"])
def mqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, head_dim]
v, # [bs, k_seq_len, head_dim]
sm_scale=1.0,
):
logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsd->bnd", weights, v)


@functools.partial(jax.jit, static_argnames=["sm_scale"])
def mha_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
sm_scale=1.0,
):
assert q.shape[1] == k.shape[2]
logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsnd->bnd", weights, v)


@functools.partial(jax.jit, static_argnames=["sm_scale"])
def gqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
sm_scale=1.0,
):
bs, num_q_heads, head_dim = q.shape
num_kv_heads = k.shape[2]
assert num_q_heads % num_kv_heads == 0
q_reshaped = q.reshape(
bs, num_kv_heads, num_q_heads // num_kv_heads, head_dim
)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
jnp.float32
)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
return o.reshape(bs, num_q_heads, head_dim)
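

# Illustrative smoke test (added; not part of the original module). It assumes
# either a GPU with Triton support or `interpret=True`, which runs the Pallas
# kernels in the (slow) interpreter and should also work on CPU.
if __name__ == "__main__":
  import math

  key_q, key_k, key_v = jax.random.split(jax.random.key(0), 3)
  batch, q_heads, kv_heads, k_seq_len, head_dim = 2, 32, 2, 256, 64
  scale = 1.0 / math.sqrt(head_dim)

  # MQA: every query head attends to a single shared kv head.
  q = jax.random.normal(key_q, (batch, 16, head_dim), dtype=jnp.float32)
  k = jax.random.normal(key_k, (batch, k_seq_len, head_dim), dtype=jnp.float32)
  v = jax.random.normal(key_v, (batch, k_seq_len, head_dim), dtype=jnp.float32)
  out = mqa(q, k, v, sm_scale=scale, interpret=True)
  ref = mqa_reference(q, k, v, sm_scale=scale)
  print("mqa max abs err:", jnp.max(jnp.abs(out - ref)))

  # GQA: groups of query heads share each kv head.
  q = jax.random.normal(key_q, (batch, q_heads, head_dim), dtype=jnp.float32)
  k = jax.random.normal(key_k, (batch, k_seq_len, kv_heads, head_dim), dtype=jnp.float32)
  v = jax.random.normal(key_v, (batch, k_seq_len, kv_heads, head_dim), dtype=jnp.float32)
  out = gqa(q, k, v, sm_scale=scale, interpret=True)
  ref = gqa_reference(q, k, v, sm_scale=scale)
  print("gqa max abs err:", jnp.max(jnp.abs(out - ref)))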