mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[XLA:GPU] Fix misspelled cuDNN custom call targets.
PiperOrigin-RevId: 609024769
This commit is contained in:
parent
3ce5bd8aa5
commit
5da43a4c55
@ -156,23 +156,23 @@ def create_dot_product_attention_backend_config(batch,
|
||||
# mapping from (is_bwd, has_dropout, has_mask, has_bias) to custom call name
|
||||
_custom_name_maps = {
|
||||
# fMHA forward call targets.
|
||||
(False, False, False, False): "__cudnn$fhmaSoftmax",
|
||||
(False, False, False, True): "__cudnn$fhmaScaleBiasSoftmax",
|
||||
(False, False, True, False): "__cudnn$fhmaScaleMaskSoftmax",
|
||||
(False, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmax",
|
||||
(False, True, False, False): "__cudnn$fhmaSoftmaxDropout",
|
||||
(False, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropout",
|
||||
(False, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropout",
|
||||
(False, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
|
||||
(False, False, False, False): "__cudnn$fmhaSoftmax",
|
||||
(False, False, False, True): "__cudnn$fmhaScaleBiasSoftmax",
|
||||
(False, False, True, False): "__cudnn$fmhaScaleMaskSoftmax",
|
||||
(False, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmax",
|
||||
(False, True, False, False): "__cudnn$fmhaSoftmaxDropout",
|
||||
(False, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropout",
|
||||
(False, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropout",
|
||||
(False, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropout",
|
||||
# fMHA backward call targets.
|
||||
(True, False, False, False): "__cudnn$fhmaSoftmaxBackward",
|
||||
(True, False, False, True): "__cudnn$fhmaScaleBiasSoftmaxBackward",
|
||||
(True, False, True, False): "__cudnn$fhmaScaleMaskSoftmaxBackward",
|
||||
(True, False, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
|
||||
(True, True, False, False): "__cudnn$fhmaSoftmaxDropoutBackward",
|
||||
(True, True, False, True): "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
|
||||
(True, True, True, False): "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
|
||||
(True, True, True, True): "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
|
||||
(True, False, False, False): "__cudnn$fmhaSoftmaxBackward",
|
||||
(True, False, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward",
|
||||
(True, False, True, False): "__cudnn$fmhaScaleMaskSoftmaxBackward",
|
||||
(True, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxBackward",
|
||||
(True, True, False, False): "__cudnn$fmhaSoftmaxDropoutBackward",
|
||||
(True, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward",
|
||||
(True, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward",
|
||||
(True, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"
|
||||
}
|
||||
|
||||
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
|
||||
|
Loading…
x
Reference in New Issue
Block a user