PR #22330: [NVIDIA] Remove logic of combining bias and mask

Imported from GitHub PR https://github.com/google/jax/pull/22330

The cuDNN API has supported combining the bias and the mask since [this PR](https://github.com/google/jax/pull/22078). We remove that logic from the public SDPA API and pass the mask through to the backend directly.
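
For illustration, a minimal sketch of calling the public SDPA API with both a bias and a boolean mask after this change. All shapes, dtypes, and values are hypothetical, and running with `implementation='cudnn'` assumes a GPU with cuDNN flash-attention support:

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64  # batch, seq_len, num_heads, head_dim (hypothetical)
kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)
query = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
key = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)
bias = jax.random.normal(kb, (B, N, T, T), dtype=jnp.bfloat16)
mask = jnp.tril(jnp.ones((1, 1, T, T), dtype=jnp.bool_))  # True = attend

# Both arguments are forwarded as-is; cudnn_dot_product_attention
# combines them internally (see google/jax#22078).
out = jax.nn.dot_product_attention(
    query, key, value, bias=bias, mask=mask,
    implementation='cudnn',  # needs a cuDNN flash-attention capable GPU
)
```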

cc. @Cjkkkk
Copybara import of the project:

--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:

Remove logic of combining bias and mask

Merging this change closes #22330

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
Kaixi Hou 2024-07-16 07:18:19 -07:00 committed by jax authors
parent 7a62b8dd18
commit 09531d2ff8


```diff
@@ -908,26 +908,8 @@ def dot_product_attention(
       )
     case 'cudnn':
       mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
-      # Convert bool mask to float mask for addition
-      if mask is not None:
-        large_negative_number = _get_large_negative(query.dtype)
-        mask = jnp.where(mask, jnp.zeros((), query.dtype),
-                         large_negative_number)
-
-      # Prepare the bias for cudnn flash attention:
-      # We should never use the mask argument of cudnn, because it is
-      # multiplicative and thus the masked values (i.e. the zeros) will
-      # still take part in the following softmax. So, we need to use the bias
-      # argument for the mask to ensure the masked values are very small.
-      # TODO(kaixih@nvidia): The logic should be moved to the internal of
-      # cudnn_dot_product_attention.
-      if bias is None:
-        bias = mask
-      elif mask is not None:
-        bias = bias + mask
-
       return cudnn_dot_product_attention(
-          query, key, value, bias, mask=None, scale=scale_val,  # type: ignore[arg-type]
+          query, key, value, bias, mask, scale=scale_val,  # type: ignore[arg-type]
           mask_type=mask_type,
       )
     case None:
```
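
For context, the deleted block hand-implemented the standard additive-mask trick that the cuDNN backend now applies internally. A standalone sketch of the idea, with illustrative values; `jnp.finfo(...).min / 2` stands in for JAX's internal `_get_large_negative`:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[True, True, False]])  # False = do not attend

# Convert the boolean mask to an additive float mask: 0 where attending,
# a large negative number where masked (stand-in for _get_large_negative).
large_negative = jnp.finfo(logits.dtype).min / 2
additive_mask = jnp.where(mask, jnp.zeros((), logits.dtype), large_negative)

probs = jax.nn.softmax(logits + additive_mask, axis=-1)
# probs[0, 2] is ~0: the masked logit no longer participates in the softmax.
```

Because the mask is additive rather than multiplicative, masked logits become very negative and receive essentially zero softmax weight, whereas a multiplicative mask would zero the logits and `exp(0) = 1` would still claim weight.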