mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
PR #22330: [NVIDIA] Remove logic of combining bias and mask
Imported from GitHub PR https://github.com/google/jax/pull/22330

The cuDNN API already supports combining bias and mask, as of [this PR](https://github.com/google/jax/pull/22078). We are therefore removing that logic from the public sdpa API and passing the mask directly.

cc. @Cjkkkk

Copybara import of the project:

-- 0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:

Remove logic of combining bias and mask

Merging this change closes #22330

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
This commit is contained in:
parent
7a62b8dd18
commit
09531d2ff8
@ -908,26 +908,8 @@ def dot_product_attention(
|
||||
)
|
||||
case 'cudnn':
|
||||
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
|
||||
# Convert bool mask to float mask for addition
|
||||
if mask is not None:
|
||||
large_negative_number = _get_large_negative(query.dtype)
|
||||
mask = jnp.where(mask, jnp.zeros((), query.dtype),
|
||||
large_negative_number)
|
||||
|
||||
# Prepare the bias for cudnn flash attention:
|
||||
# We should never use the mask argument of cudnn, because it is
|
||||
# multiplicative and thus the masked values (i.e. the zeros) will
|
||||
# still take part in the following softmax. So, we need to use the bias
|
||||
# argument for the mask to ensure the masked values are very small.
|
||||
# TODO(kaixih@nvidia): The logic should be moved to the internal of
|
||||
# cudnn_dot_product_attention.
|
||||
if bias is None:
|
||||
bias = mask
|
||||
elif mask is not None:
|
||||
bias = bias + mask
|
||||
|
||||
return cudnn_dot_product_attention(
|
||||
query, key, value, bias, mask=None, scale=scale_val, # type: ignore[arg-type]
|
||||
query, key, value, bias, mask, scale=scale_val, # type: ignore[arg-type]
|
||||
mask_type=mask_type,
|
||||
)
|
||||
case None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user