[XLA:GPU] Fix misspelled cuDNN custom call targets.

PiperOrigin-RevId: 609024769
This commit is contained in:
Benjamin Chetioui 2024-02-21 09:34:29 -08:00 committed by jax authors
parent 3ce5bd8aa5
commit 5da43a4c55

View File

@ -156,23 +156,23 @@ def create_dot_product_attention_backend_config(batch,
# mapping from (is_bwd, has_dropout, has_mask, has_bias) to custom call name
_custom_name_maps = {
# fMHA forward call targets.
(False, False, False, False): "__cudnn$fhmaSoftmax",
(False, False, False, True): "__cudnn$fhmaScaleBiasSoftmax",
(False, False, True, False): "__cudnn$fhmaScaleMaskSoftmax",
(False, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmax",
(False, True, False, False): "__cudnn$fhmaSoftmaxDropout",
(False, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropout",
(False, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropout",
(False, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
(False, False, False, False): "__cudnn$fmhaSoftmax",
(False, False, False, True): "__cudnn$fmhaScaleBiasSoftmax",
(False, False, True, False): "__cudnn$fmhaScaleMaskSoftmax",
(False, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmax",
(False, True, False, False): "__cudnn$fmhaSoftmaxDropout",
(False, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropout",
(False, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropout",
(False, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropout",
# fMHA backward call targets.
(True, False, False, False): "__cudnn$fhmaSoftmaxBackward",
(True, False, False, True): "__cudnn$fhmaScaleBiasSoftmaxBackward",
(True, False, True, False): "__cudnn$fhmaScaleMaskSoftmaxBackward",
(True, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
(True, True, False, False): "__cudnn$fhmaSoftmaxDropoutBackward",
(True, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
(True, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
(True, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
(True, False, False, False): "__cudnn$fmhaSoftmaxBackward",
(True, False, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward",
(True, False, True, False): "__cudnn$fmhaScaleMaskSoftmaxBackward",
(True, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxBackward",
(True, True, False, False): "__cudnn$fmhaSoftmaxDropoutBackward",
(True, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward",
(True, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward",
(True, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"
}
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):