Support GQA and MQA

kaixih 2024-07-15 22:07:08 +00:00
parent 0d5dae09ff
commit cf5bcc7ad8
2 changed files with 95 additions and 43 deletions

View File

@@ -784,16 +784,10 @@ def _get_large_negative(dtype):
def _get_causal_mask(T, S, dtype):
pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
return mask[jnp.newaxis, jnp.newaxis, :, :]
return mask
def _dot_product_attention_xla(
query: Array,
key: Array,
value: Array,
bias: Array | None,
mask: Array | None,
is_causal: bool,
scale: float):
def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
@@ -811,8 +805,9 @@ def _dot_product_attention_xla(
padded_logits = logits
if is_causal:
T, S = query.shape[-3], key.shape[-3]
mask = _get_causal_mask(T, S, logits.dtype)
T, S = query.shape[1], key.shape[1]
mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype),
padded_logits.shape)
padded_logits = padded_logits + mask
# Softmax is always carried out in fp32.
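For reference, a minimal sketch (not part of the diff) of the rearranged causal masking: the helper now returns a plain (T, S) additive mask and the caller broadcasts it against the (B, N, T, S) logits. The `big_neg` constant below is an illustrative stand-in for `_get_large_negative`.

import jax.numpy as jnp

def causal_bias(T, S, dtype=jnp.float32):
  # Lower-triangular positions get 0.0; everything else gets a large negative
  # value so those logits vanish after softmax.
  pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
  big_neg = jnp.asarray(jnp.finfo(dtype).min / 2, dtype)  # stand-in for _get_large_negative
  return jnp.where(pred, jnp.asarray(0.0, dtype), big_neg)

logits = jnp.zeros((2, 4, 8, 8))  # (B, N, T, S)
masked_logits = logits + jnp.broadcast_to(causal_bias(8, 8), logits.shape)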
@@ -822,6 +817,38 @@ def _dot_product_attention_xla(
encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
return encoded
def _dot_product_attention_xla(
query: Array,
key: Array,
value: Array,
bias: Array | None,
mask: Array | None,
is_causal: bool,
scale: float):
B, T, N, H = query.shape
_, S, K, _ = key.shape
G = N // K
query = jnp.reshape(query, (B, T, K, G, H))
def _reshape_to_grouped(t):
if t is not None:
tB, tN, tT, tS = t.shape
if tN == 1:
t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
else:
assert tN == N
t = jnp.reshape(t, (tB, K, G, tT, tS))
return t
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
vmapped_fn = jax.vmap(_dot_product_attention_core,
in_axes=(3, None, None, 2, 2, None, None),
out_axes=3)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded
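A standalone sketch (illustrative, not the library code) of the grouping trick used by the new wrapper: split the N query heads into K groups of G = N // K, vmap a per-group attention core over the group axis with key/value shared, then merge the heads back. The toy `attention_core` below is an assumption for the example; the real code vmaps `_dot_product_attention_core` and also maps the grouped bias/mask.

import jax
import jax.numpy as jnp

def attention_core(q, k, v):
  # q: (B, T, K, H), k/v: (B, S, K, H) -- one query head per key/value head
  logits = jnp.einsum('BTKH,BSKH->BKTS', q, k)
  probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
  return jnp.einsum('BKTS,BSKH->BTKH', probs, v)

def grouped_attention(query, key, value):
  B, T, N, H = query.shape
  _, S, K, _ = key.shape
  G = N // K                               # query heads per key/value head
  q = jnp.reshape(query, (B, T, K, G, H))  # split the N heads into K groups of G
  # Map over the group axis of q; key and value are shared across groups.
  out = jax.vmap(attention_core, in_axes=(3, None, None), out_axes=3)(q, key, value)
  return jnp.reshape(out, (B, T, N, H))    # (B, T, K, G, H) -> (B, T, N, H)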
def dot_product_attention(
query: ArrayLike,
key: ArrayLike,
@@ -844,24 +871,31 @@ def dot_product_attention(
:code:`probs` as the output of :math:`softmax`.
Throughout this function, we utilize the following uppercase letters to
represent the shape of array:
represent the shape of arrays::
B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head
K = number of key/value heads
G = number of groups, which equals N // K
Args:
query: query array; shape :code:`(BTNH)`
key: key array; shape :code:`(BSNH)`
value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
:code:`(BNTS)`.
key: key array; shape :code:`(BSKH)`. When `K` equals `N`, multi-headed
attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise,
grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed
if `N` is a multiple of `K`, and multi-query attention (MQA:
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
of GQA).
value: value array, should have the same shape as the `key` array.
bias: optional, bias array to be added to logits; the shape must be 4D and
broadcastable to :code:`(BNTS)`.
mask: optional, mask array used to filter out logits. It is a boolean mask
where `True` indicates the element should take part in attention. For an
additive mask, users should pass it to `bias`. The shape is broadcastable
to :code:`(BNTS)`.
additive mask, users should pass it to `bias`. The shape must be 4D and
broadcastable to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
@@ -891,15 +925,22 @@ def dot_product_attention(
bias = bias if bias is None else jnp.asarray(bias)
mask = mask if mask is None else jnp.asarray(mask)
B, S, N, H = key.shape
_check_has_shape(value, [B, S, N, H], 'value')
_check_has_shape(query, [B, -1, N, H], 'query')
scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
B, S, K, H = key.shape
_check_has_shape(value, [B, S, K, H], 'value')
_check_has_shape(query, [B, -1, -1, H], 'query')
if query.shape[-2] % K != 0:
raise ValueError(f"The number of query heads must be a multiple of the "
f"number of key/value heads, but got {query.shape[-2]} vs {K}.")
if not (query.dtype == key.dtype == value.dtype):
raise ValueError(f"query/key/value should have the same dtype, but got "
f"{query.dtype} vs {key.dtype} vs {value.dtype}.")
if mask is not None and mask.dtype != jnp.bool_:
raise ValueError(f"Mask must be boolean dtype, but got {mask.dtype}.")
raise ValueError(f"query/key/value should have the same shape, but got "
f"{query.shape} vs {key.shape} vs {value.shape}.")
if mask is not None and (mask.dtype != jnp.bool_ or mask.ndim != 4):
raise ValueError(f"Mask must be a 4D boolean tensor, but got "
f"rank={mask.ndim}, dtype={mask.dtype}.")
if bias is not None and bias.ndim != 4:
raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.")
scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
match implementation:
case 'xla':
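As a usage sketch for the updated public API (shapes and dtype chosen for illustration): K == N gives plain multi-head attention, 1 < K < N gives GQA, and K == 1 gives MQA. The query keeps shape (B, T, N, H) while key/value shrink to (B, S, K, H).

import jax
import jax.numpy as jnp

B, T, S, N, K, H = 2, 16, 16, 8, 2, 32   # 8 query heads sharing 2 kv heads (GQA)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), jnp.bfloat16)
k = jax.random.normal(kk, (B, S, K, H), jnp.bfloat16)
v = jax.random.normal(kv, (B, S, K, H), jnp.bfloat16)
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)  # (B, T, N, H)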

View File

@@ -54,22 +54,24 @@ def _get_causal_mask(T, S):
class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=(False, True),
causal_mode=(None, 'is_causal', 'is_mask'),
impl=('xla', 'cudnn'),
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
group_num=[1, 2, 4],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, impl):
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
group_num, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
sdpa = nn.dot_product_attention
B, S, T, N, H = 2, 128, 128, 4, 32
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
keys = random.split(random.PRNGKey(0), 4)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N, H), dtype)
V = random.normal(keys[2], (B, S, N, H), dtype)
K = random.normal(keys[1], (B, S, N // G, H), dtype)
V = random.normal(keys[2], (B, S, N // G, H), dtype)
if use_bias:
bias = random.normal(keys[3], (1, N, T, S), dtype)
else:
@@ -86,7 +88,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)
out_ref = sdpa_ref(Q, K, V, bias=bias, mask=causal_mask)
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask)
out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
@@ -94,20 +99,22 @@ class NNFunctionsTest(jtu.JaxTestCase):
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
group_num=[1, 2, 4],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, impl):
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
group_num, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
sdpa = nn.dot_product_attention
B, S, T, N, H = 2, 128, 128, 4, 32
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
keys = random.split(random.PRNGKey(0), 5)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N, H), dtype)
V = random.normal(keys[2], (B, S, N, H), dtype)
K = random.normal(keys[1], (B, S, N // G, H), dtype)
V = random.normal(keys[2], (B, S, N // G, H), dtype)
grad = random.normal(keys[3], (B, T, N, H), dtype)
if use_bias:
bias = random.normal(keys[4], (1, N, T, S), dtype)
@@ -117,10 +124,15 @@ class NNFunctionsTest(jtu.JaxTestCase):
is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K, V, bias, causal_mask)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask)
dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)
if G != 1:
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
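The dK_ref/dV_ref reduction a few lines above follows from the chain rule: in the repeated reference, each key/value head is consumed by its G query heads, so the gradient for the shared head is the sum of the per-group gradients. A minimal sketch of that collapse (toy values, not the test code):

import jax.numpy as jnp

B, S, N, H, G = 2, 16, 4, 8, 2
K = N // G
dK_repeated = jnp.ones((B, S, N, H))                         # grad w.r.t. the repeated keys
dK_shared = dK_repeated.reshape(B, S, K, G, H).sum(axis=3)   # (B, S, K, H)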
@@ -132,10 +144,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')
rtol, atol = (.01, .01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=rtol, atol=atol)
self.assertAllClose(dK_ref, dK_ans, rtol=rtol, atol=atol)
self.assertAllClose(dV_ref, dV_ans, rtol=rtol, atol=atol)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)
@jtu.skip_on_flag("jax_skip_slow_tests", True)