diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 2769b2af1..4779c8a0b 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -908,26 +908,8 @@ def dot_product_attention(
       )
     case 'cudnn':
       mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
-      # Convert bool mask to float mask for addition
-      if mask is not None:
-        large_negative_number = _get_large_negative(query.dtype)
-        mask = jnp.where(mask, jnp.zeros((), query.dtype),
-                         large_negative_number)
-
-      # Prepare the bias for cudnn flash attention:
-      # We should never use the mask argument of cudnn, because it is
-      # multiplicative and thus the masked values (i.e. the zeros) will
-      # still take part in the following softmax. So, we need to use the bias
-      # argument for the mask to ensure the masked values are very small.
-      # TODO(kaixih@nvidia): The logic should be moved to the internal of
-      # cudnn_dot_product_attention.
-      if bias is None:
-        bias = mask
-      elif mask is not None:
-        bias = bias + mask
-
       return cudnn_dot_product_attention(
-          query, key, value, bias, mask=None, scale=scale_val,  # type: ignore[arg-type]
+          query, key, value, bias, mask, scale=scale_val,  # type: ignore[arg-type]
           mask_type=mask_type,
       )
     case None:
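
For reference, the block deleted above (which the TODO expects to live inside cudnn_dot_product_attention) converts a boolean mask into an additive bias so that masked positions become very negative before the softmax, rather than relying on a multiplicative mask. The following is a minimal, standalone sketch of that conversion, not code from this diff: the function name is hypothetical, and the private _get_large_negative helper is approximated by a large finite negative value of the query dtype.

import jax.numpy as jnp

def bool_mask_to_additive_bias(mask, bias, dtype):
  # Sketch only: True entries contribute 0; False entries contribute a large
  # negative number so they effectively drop out after the softmax.
  if mask is None:
    return bias
  large_negative_number = jnp.asarray(-0.7 * jnp.finfo(dtype).max, dtype)
  additive = jnp.where(mask, jnp.zeros((), dtype), large_negative_number)
  return additive if bias is None else bias + additive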