fix dbias in bwd_batcher

kaixih 2024-09-18 21:23:16 +00:00
parent d0cb3182aa
commit b7e26ba3ee
2 changed files with 127 additions and 0 deletions

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import operator
import math
import numpy as np
from typing import Any, Literal
import warnings
@@ -34,6 +35,8 @@ from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType
from jax._src.ops.special import logsumexp as _logsumexp
@@ -900,6 +903,68 @@ def _dot_product_attention_xla(
  encoded = jnp.reshape(encoded, (B, T, N, H))
  return encoded

def bias_fwd_rule(a, query_head_num):
  return bias_fwd_p.bind(a, query_head_num), a

def bias_bwd_rule(query_head_num, res, g):
  a = res
  if a.shape[0] > 1 or a.shape[-3] != query_head_num:
    raise ValueError("cuDNN only supports bias gradient when the batch size is "
                     f"1 and the head number matches the query, but got "
                     f"B={a.shape[0]}, N={a.shape[-3]}.")
  return (bias_bwd_p.bind(g, a, query_head_num),)

# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work
# around a cuDNN issue where bias gradients are only supported when the batch
# size is 1 and the number of heads matches the query.
# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def check_valid_bias_batch(x, query_head_num):
  output, _ = bias_fwd_rule(x, query_head_num)
  return output

check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule)

bias_fwd_p = core.Primitive('bias_fwd')
bias_fwd_p.multiple_results = False
bias_bwd_p = core.Primitive('bias_bwd')
bias_bwd_p.multiple_results = False

def bias_fwd_impl(a, query_head_num):
  return a

def bias_bwd_impl(g, a, query_head_num):
  return g

bias_fwd_p.def_impl(bias_fwd_impl)
bias_bwd_p.def_impl(bias_bwd_impl)

def bias_fwd_abstract_eval(a, query_head_num):
  return core.ShapedArray(a.shape, a.dtype)

def bias_bwd_abstract_eval(g, a, query_head_num):
  return core.ShapedArray(g.shape, g.dtype)

bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval)
bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval)

def bias_fwd_lowering(ctx, a, query_head_num):
  return [a]

def bias_bwd_lowering(ctx, g, a, query_head_num):
  return [g]

mlir.register_lowering(bias_fwd_p, bias_fwd_lowering)
mlir.register_lowering(bias_bwd_p, bias_bwd_lowering)

def bias_fwd_batch_rule(batched_args, batch_dims):
  x, query_head_num = batched_args
  a = batch_dims[0]
  output, _ = bias_fwd_rule(x, query_head_num)
  return output, a

def bias_bwd_batch_rule(batched_args, batch_dims):
  g, x, query_head_num = batched_args
  b = batch_dims[0]
  *Bs, _, _, _ = x.shape
  B = math.prod(Bs)
  x = jnp.reshape(x, (B,) + x.shape[-3:])
  output, = bias_bwd_rule(query_head_num, x, g)
  return output, b

batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule
batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule

def dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
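
The workaround above is an identity pass-through at trace time: `check_valid_bias_batch` returns the bias unchanged, and only its custom VJP (`bias_bwd_rule`) enforces the cuDNN constraint, so the error surfaces only when a gradient with respect to the bias is actually requested. A minimal sketch of that behaviour, not part of the diff; it assumes `check_valid_bias_batch` from the hunk above is in scope and uses illustrative shapes:

import jax
import jax.numpy as jnp

N, S = 4, 128                          # query heads, sequence length
ok_bias = jnp.zeros((1, N, S, S))      # B=1 and head dim matches the query
bad_bias = jnp.zeros((16, N, S, S))    # B=16: cuDNN cannot emit dbias

# Forward pass: the bias comes back unchanged.
out = check_valid_bias_batch(ok_bias, N)

# The check only fires under differentiation w.r.t. the bias.
dbias = jax.grad(lambda b: check_valid_bias_batch(b, N).sum())(ok_bias)   # ok
# jax.grad(lambda b: check_valid_bias_batch(b, N).sum())(bad_bias)
#   -> ValueError: cuDNN only supports bias gradient ...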
@@ -1032,6 +1097,9 @@ def dot_product_attention(
          local_window_size=local_window_size,
      )
    case 'cudnn':
      if bias is not None:
        bias = check_valid_bias_batch(bias, query_arr.shape[-2])
        bias = jnp.asarray(bias)
      use_padding = (
          query_seq_lengths is not None or key_value_seq_lengths is not None
      )
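
With this change, the cuDNN branch routes any user-supplied bias through `check_valid_bias_batch` before lowering, so an unsupported bias batch is rejected with the clear ValueError above. A hedged usage sketch, not part of the diff; it assumes a cuDNN-capable GPU and uses illustrative shapes and names:

from functools import partial
import jax
import jax.numpy as jnp

B, S, N, H = 1, 128, 4, 32
q = jnp.zeros((B, S, N, H), jnp.bfloat16)
bias = jnp.zeros((1, N, S, S), jnp.bfloat16)   # B=1, heads match the query

# Force the cuDNN backend; dbias is only supported for this bias shape.
attn = partial(jax.nn.dot_product_attention, implementation='cudnn')
loss = lambda b: attn(q, q, q, bias=b).sum()
dbias = jax.grad(loss)(bias)   # works; a bias with batch > 1 would raise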

View File

@@ -50,6 +50,8 @@ def _check_cudnn_backend(fn, *args, **kwargs):
  hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
  return '__cudnn$fmha' in hlo

_cudnn_dbias_error = 'cuDNN only supports bias gradient'

@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@@ -167,6 +169,63 @@ class NNFunctionsTest(jtu.JaxTestCase):
    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
    self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)

  @parameterized.product(
      batch_size=[1, 16],
      use_vmap=[False, True],
  )
  def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
    if not _is_required_cudnn_version_satisfied(8904):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

    dtype = jnp.bfloat16
    B, S, N, H = batch_size, 128, 4, 32
    keys = random.split(random.PRNGKey(0), 2)
    x = random.normal(keys[0], (B, S, N, H), dtype)
    bias = random.normal(keys[1], (B, N, S, S), dtype=dtype)
    mask = jnp.ones((1, 1, S), dtype=jnp.bool_)

    def attention(x, bias, mask, impl):
      return jax.nn.dot_product_attention(
          query=x,
          key=x,
          value=x,
          bias=bias,
          mask=mask,
          is_causal=False,
          implementation=impl,
      )
    attn_ref = partial(attention, impl=None)
    attn_ans = partial(attention, impl='cudnn')

    if use_vmap:
      attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
      attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))
    else:
      attn_batched_ref = attn_ref
      attn_batched_ans = attn_ans

    fwd_ref = jax.jit(attn_batched_ref)
    fwd_ans = jax.jit(attn_batched_ans)
    y_ref = fwd_ref(x, bias, mask)
    y_ans = fwd_ans(x, bias, mask)
    self.assertAllClose(y_ref, y_ans)

    @jax.jit
    def bwd_ref(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ref, x, bias, mask)
      return f_vjp(x)

    @jax.jit
    def bwd_ans(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ans, x, bias, mask)
      return f_vjp(x)

    if batch_size != 1:
      with self.assertRaisesRegex(ValueError, _cudnn_dbias_error):
        _, dbias_ans, _ = bwd_ans(x, bias, mask)
    else:
      _, dbias_ref, _ = bwd_ref(x, bias, mask)
      _, dbias_ans, _ = bwd_ans(x, bias, mask)
      self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSoftplusGrad(self):
    check_grads(nn.softplus, (1e-8,), order=4,