mirror of https://github.com/ROCm/jax.git
Support GQA and MQA
parent 0d5dae09ff
commit cf5bcc7ad8
@@ -784,16 +784,10 @@ def _get_large_negative(dtype):
 def _get_causal_mask(T, S, dtype):
   pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
   mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
-  return mask[jnp.newaxis, jnp.newaxis, :, :]
+  return mask
 
-def _dot_product_attention_xla(
-    query: Array,
-    key: Array,
-    value: Array,
-    bias: Array | None,
-    mask: Array | None,
-    is_causal: bool,
-    scale: float):
+def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
+                                scale):
   logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
   logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                       preferred_element_type=logits_dtype)
@@ -811,8 +805,9 @@ def _dot_product_attention_xla(
     padded_logits = logits
 
   if is_causal:
-    T, S = query.shape[-3], key.shape[-3]
-    mask = _get_causal_mask(T, S, logits.dtype)
+    T, S = query.shape[1], key.shape[1]
+    mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype),
+                            padded_logits.shape)
     padded_logits = padded_logits + mask
 
   # Softmax and it is always carried out in fp32.
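A side note on the causal-mask change above: `_get_causal_mask` now returns a plain 2-D additive mask (zeros on and below the diagonal, a large negative value above it) that is explicitly broadcast to the (B, N, T, S) logits. A minimal standalone sketch of that pattern, using an illustrative stand-in for `_get_large_negative`:

import jax.numpy as jnp

def causal_bias(T, S, dtype=jnp.float32):
  # 0.0 where attention is allowed (on/below the diagonal), a large negative
  # value where it is not, so softmax drives those probabilities to ~0.
  pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
  big_neg = jnp.asarray(jnp.finfo(dtype).min * 0.5, dtype)  # stand-in for _get_large_negative
  return jnp.where(pred, jnp.asarray(0.0, dtype), big_neg)

B, N, T, S = 2, 4, 8, 8
logits = jnp.zeros((B, N, T, S))
# The 2-D (T, S) mask broadcasts over the batch and head dimensions.
masked_logits = logits + jnp.broadcast_to(causal_bias(T, S), logits.shape)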
@@ -822,6 +817,38 @@ def _dot_product_attention_xla(
   encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
   return encoded
 
+def _dot_product_attention_xla(
+    query: Array,
+    key: Array,
+    value: Array,
+    bias: Array | None,
+    mask: Array | None,
+    is_causal: bool,
+    scale: float):
+
+  B, T, N, H = query.shape
+  _, S, K, _ = key.shape
+  G = N // K
+
+  query = jnp.reshape(query, (B, T, K, G, H))
+  def _reshape_to_grouped(t):
+    if t is not None:
+      tB, tN, tT, tS = t.shape
+      if tN == 1:
+        t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
+      else:
+        assert tN == N
+        t = jnp.reshape(t, (tB, K, G, tT, tS))
+    return t
+  bias = _reshape_to_grouped(bias)
+  mask = _reshape_to_grouped(mask)
+  vmapped_fn = jax.vmap(_dot_product_attention_core,
+                        in_axes=(3, None, None, 2, 2, None, None),
+                        out_axes=3)
+  encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
+  encoded = jnp.reshape(encoded, (B, T, N, H))
+  return encoded
+
 def dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
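For readability outside the diff, here is a hedged sketch of the grouping trick used by the new `_dot_product_attention_xla`: split the N query heads into (K, G), vmap a per-kv-head attention core over the group axis, then merge the heads back. Function names and shapes below are illustrative, not the library's internals.

import jax
import jax.numpy as jnp

def attention_core(q, k, v):
  # q: (B, T, K, H); k, v: (B, S, K, H) -- one query head per kv head.
  logits = jnp.einsum('BTKH,BSKH->BKTS', q, k)
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('BKTS,BSKH->BTKH', probs, v)

def grouped_attention(query, key, value):
  B, T, N, H = query.shape
  _, S, K, _ = key.shape
  G = N // K                        # query heads per key/value head
  q = query.reshape(B, T, K, G, H)  # split the head axis into (K, G)
  # Map over the group axis of the query; key/value are shared by each group.
  out = jax.vmap(attention_core, in_axes=(3, None, None), out_axes=3)(q, key, value)
  return out.reshape(B, T, N, H)

q = jnp.ones((2, 8, 4, 16))  # N = 4 query heads
k = jnp.ones((2, 8, 2, 16))  # K = 2 key/value heads -> G = 2
v = jnp.ones((2, 8, 2, 16))
print(grouped_attention(q, k, v).shape)  # (2, 8, 4, 16)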
@@ -844,24 +871,31 @@ def dot_product_attention(
   :code:`probs` as the output of :math:`softmax`.
 
   Throughout this function, we utilize the following uppercase letters to
-  represent the shape of array:
+  represent the shape of array::
 
     B = batch size
     S = length of the key/value (source)
     T = length of the query (target)
     N = number of attention heads
     H = dimensions of each attention head
+    K = number of key/value heads
+    G = number of groups, which equals N // K
 
   Args:
     query: query array; shape :code:`(BTNH)`
-    key: key array; shape :code:`(BSNH)`
-    value: value array; shape :code:`(BSNH)`
-    bias: optional, bias array to be added to logits; shape broadcastable to
-      :code:`(BNTS)`.
+    key: key array; shape :code:`(BSKH)`. When `K` equals `N`, multi-headed
+      attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise,
+      grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed
+      if `N` is a multiple of `K`, and multi-query attention (MQA:
+      https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
+      of GQA).
+    value: value array; should have the same shape as the `key` array.
+    bias: optional, bias array to be added to logits; the shape must be 4D and
+      broadcastable to :code:`(BNTS)`.
     mask: optional, mask array used to filter out logits. It is a boolean mask
       where `True` indicates the element should take part in attention. For an
-      additive mask, users should pass it to `bias`. The shape is broadcastable
-      to :code:`(BNTS)`.
+      additive mask, users should pass it to `bias`. The shape must be 4D and
+      broadcastable to :code:`(BNTS)`.
     scale: scale for the logits. If None, the scale will be set to 1 divided by
       the square root of query's head dimension (i.e. H).
     is_causal: If true, causal attention will be applied. Note, some
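Assuming the docstring above, a usage sketch of the three regimes from the caller's side; `jax.nn.dot_product_attention` is the public entry point being modified here, and the shapes are arbitrary illustrations.

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 16, 16, 8, 64
rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (B, T, N, H))

# MHA: the number of key/value heads K equals the number of query heads N.
k = jax.random.normal(rng_k, (B, S, N, H))
v = jax.random.normal(rng_v, (B, S, N, H))
out_mha = jax.nn.dot_product_attention(q, k, v)

# GQA: N is a multiple of K (here K = 2, so each group has G = 4 query heads).
out_gqa = jax.nn.dot_product_attention(q, k[:, :, :2, :], v[:, :, :2, :], is_causal=True)

# MQA: K == 1, a special case of GQA where all query heads share one kv head.
out_mqa = jax.nn.dot_product_attention(q, k[:, :, :1, :], v[:, :, :1, :])

print(out_mha.shape, out_gqa.shape, out_mqa.shape)  # all (2, 16, 8, 64)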
@@ -891,15 +925,22 @@ def dot_product_attention(
   bias = bias if bias is None else jnp.asarray(bias)
   mask = mask if mask is None else jnp.asarray(mask)
 
-  B, S, N, H = key.shape
-  _check_has_shape(value, [B, S, N, H], 'value')
-  _check_has_shape(query, [B, -1, N, H], 'query')
-  scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
+  B, S, K, H = key.shape
+  _check_has_shape(value, [B, S, K, H], 'value')
+  _check_has_shape(query, [B, -1, -1, H], 'query')
+  if query.shape[-2] % K != 0:
+    raise ValueError(f"The number of query heads must be a multiple of "
+                     f"key/value heads, but got {query.shape[-2]} vs {K}")
   if not (query.dtype == key.dtype == value.dtype):
-    raise ValueError(f"query/key/value should have the same dtype, but got "
-                     f"{query.dtype} vs {key.dtype} vs {value.dtype}.")
-  if mask is not None and mask.dtype != jnp.bool_:
-    raise ValueError(f"Mask must be boolean dtype, but got {mask.dtype}.")
+    raise ValueError(f"query/key/value should have the same shape, but got "
+                     f"{query.shape} vs {key.shape} vs {value.shape}.")
+  if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
+    raise ValueError(f"Mask must be a 4D boolean tensor, but got "
+                     f"rank={mask.ndim}, dtype={mask.dtype}.")
+  if bias is not None and bias.ndim != 4:
+    raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.")
+
+  scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
 
   match implementation:
     case 'xla':
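A short illustration of the validation added above, with made-up shapes: the number of query heads must be a multiple of the number of key/value heads, and bias (like mask) must be rank-4.

import jax
import jax.numpy as jnp

B, T, S, H = 2, 8, 8, 32
q = jnp.zeros((B, T, 8, H))   # N = 8 query heads

# Accepted: 8 query heads over 4 kv heads (G = 2) and over 1 kv head (MQA).
jax.nn.dot_product_attention(q, jnp.zeros((B, S, 4, H)), jnp.zeros((B, S, 4, H)))
jax.nn.dot_product_attention(q, jnp.zeros((B, S, 1, H)), jnp.zeros((B, S, 1, H)))

# Rejected: 8 is not a multiple of 3, so the new head-count check raises.
try:
  jax.nn.dot_product_attention(q, jnp.zeros((B, S, 3, H)), jnp.zeros((B, S, 3, H)))
except ValueError as err:
  print(err)

# Rejected: a (T, S) bias is not rank-4, even though it would broadcast
# against (B, N, T, S).
try:
  jax.nn.dot_product_attention(q, jnp.zeros((B, S, 8, H)), jnp.zeros((B, S, 8, H)),
                               bias=jnp.zeros((T, S)))
except ValueError as err:
  print(err)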
@@ -54,22 +54,24 @@ def _get_causal_mask(T, S):
 class NNFunctionsTest(jtu.JaxTestCase):
   @parameterized.product(
       dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
-      use_bias=(False, True),
-      causal_mode=(None, 'is_causal', 'is_mask'),
-      impl=('xla', 'cudnn'),
+      use_bias=[False, True],
+      causal_mode=[None, 'is_causal', 'is_mask'],
+      group_num=[1, 2, 4],
+      impl=['xla', 'cudnn'],
   )
-  def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, impl):
+  def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
+                                   group_num, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
 
     sdpa = nn.dot_product_attention
-    B, S, T, N, H = 2, 128, 128, 4, 32
+    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
     keys = random.split(random.PRNGKey(0), 4)
     Q = random.normal(keys[0], (B, T, N, H), dtype)
-    K = random.normal(keys[1], (B, S, N, H), dtype)
-    V = random.normal(keys[2], (B, S, N, H), dtype)
+    K = random.normal(keys[1], (B, S, N // G, H), dtype)
+    V = random.normal(keys[2], (B, S, N // G, H), dtype)
     if use_bias:
       bias = random.normal(keys[3], (1, N, T, S), dtype)
     else:
@@ -86,7 +88,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
       hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
       self.assertIn('__cudnn$fmha', hlo)
 
-    out_ref = sdpa_ref(Q, K, V, bias=bias, mask=causal_mask)
+    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
+    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
+    out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask)
+
     out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask)
     self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
 
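The reference path in this test expands the K key/value heads back to N by repetition; as a standalone check (a hedged sketch with arbitrary small shapes), grouped attention should match plain multi-head attention on the repeated keys/values:

import jax
import jax.numpy as jnp

B, T, S, N, K, H = 1, 4, 4, 4, 2, 8
G = N // K
ks = jax.random.split(jax.random.PRNGKey(1), 3)
q = jax.random.normal(ks[0], (B, T, N, H))
k = jax.random.normal(ks[1], (B, S, K, H))
v = jax.random.normal(ks[2], (B, S, K, H))

out_gqa = jax.nn.dot_product_attention(q, k, v)
# Repeating each kv head G times reduces GQA to ordinary MHA.
out_mha = jax.nn.dot_product_attention(q, jnp.repeat(k, G, axis=2),
                                        jnp.repeat(v, G, axis=2))
print(jnp.allclose(out_gqa, out_mha, atol=1e-5))  # expected True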
@@ -94,20 +99,22 @@ class NNFunctionsTest(jtu.JaxTestCase):
       dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
       use_bias=[False, True],
       causal_mode=[None, 'is_causal', 'is_mask'],
+      group_num=[1, 2, 4],
       impl=['xla', 'cudnn'],
   )
-  def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, impl):
+  def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
+                                   group_num, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
 
     sdpa = nn.dot_product_attention
-    B, S, T, N, H = 2, 128, 128, 4, 32
+    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
     keys = random.split(random.PRNGKey(0), 5)
     Q = random.normal(keys[0], (B, T, N, H), dtype)
-    K = random.normal(keys[1], (B, S, N, H), dtype)
-    V = random.normal(keys[2], (B, S, N, H), dtype)
+    K = random.normal(keys[1], (B, S, N // G, H), dtype)
+    V = random.normal(keys[2], (B, S, N // G, H), dtype)
     grad = random.normal(keys[3], (B, T, N, H), dtype)
     if use_bias:
       bias = random.normal(keys[4], (1, N, T, S), dtype)
@@ -117,10 +124,15 @@ class NNFunctionsTest(jtu.JaxTestCase):
     is_causal = causal_mode == 'is_causal'
     causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
 
+    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
+    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
     fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m)
-    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K, V, bias, causal_mask)
+    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask)
     dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)
+    if G != 1:
+      dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
+      dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
 
     sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
     fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
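Why the reference gradients are folded above: repeating each key/value head G times means the cotangent of the original head is the sum over its G copies, which is exactly what the reshape-and-sum recovers. A small numeric sketch of that identity (names here are illustrative):

import jax
import jax.numpy as jnp

B, S, K, G, H = 1, 3, 2, 2, 4
kv = jnp.ones((B, S, K, H))

# Gradient taken directly through the repeat: each original head receives
# the summed contribution of its G copies.
g_direct = jax.grad(lambda x: jnp.repeat(x, G, axis=2).sum())(kv)      # (B, S, K, H)

# Gradient taken w.r.t. the already-repeated tensor, then folded back by
# reshaping the head axis into (K, G) and summing over G, as the test does.
g_repeated = jax.grad(lambda y: y.sum())(jnp.repeat(kv, G, axis=2))    # (B, S, K*G, H)
g_folded = g_repeated.reshape(B, S, K, G, H).sum(axis=3)

print(jnp.allclose(g_direct, g_folded))  # True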
@@ -132,10 +144,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
       hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
       self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')
 
-    rtol, atol = (.01, .01)
-    self.assertAllClose(dQ_ref, dQ_ans, rtol=rtol, atol=atol)
-    self.assertAllClose(dK_ref, dK_ans, rtol=rtol, atol=atol)
-    self.assertAllClose(dV_ref, dV_ans, rtol=rtol, atol=atol)
+    self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
+    self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
+    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
     self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)
 
   @jtu.skip_on_flag("jax_skip_slow_tests", True)