Fixes the JAX implementation of CELU returning NaN gradients for input

values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
This commit is contained in:
jax authors 2022-07-18 11:57:29 -07:00
parent d98d5ddce5
commit b0805a8a31

View File

@ -200,7 +200,7 @@ def celu(x: Array, alpha: Array = 1.0) -> Array:
x : input array
alpha : array or scalar (default: 1.0)
"""
return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
@jax.jit
def selu(x: Array) -> Array: