diff --git a/jax/nn/functions.py b/jax/nn/functions.py
index d035fba59..5b67c0f2b 100644
--- a/jax/nn/functions.py
+++ b/jax/nn/functions.py
@@ -51,7 +51,7 @@ def selu(x):
   """Scaled exponential linear unit activation"""
   alpha = 1.6732632423543772848170429916717
   scale = 1.0507009873554804934193349852946
-  return scale * leaky_relu(x, alpha)
+  return scale * elu(x, alpha)
 
 @jit
 def gelu(x):
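
For reviewers, a minimal sketch (not part of the patch) of why selu must be built on elu rather than leaky_relu: selu's negative branch is scale * alpha * (exp(x) - 1), which is exactly scale * elu(x, alpha), whereas leaky_relu would produce the linear branch scale * alpha * x. The snippet assumes a JAX install that already includes this change; the jnp and jax.nn names used are the standard public ones.

# Sketch only, not part of the diff: compare both formulations against the
# SELU reference definition scale * (x if x > 0 else alpha * (exp(x) - 1)).
import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-3.0, 3.0, 7)
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

reference = scale * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

# Matches the reference once this patch is applied.
print(jnp.allclose(nn.selu(x), reference))                        # True
# The pre-patch formulation scales negative inputs linearly instead.
print(jnp.allclose(scale * nn.leaky_relu(x, alpha), reference))   # False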