mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229. When a JAX where() op is used to avoid a NaN or undefined value, reverse differentiation can still return NaN even though the NaN input is not selected by the conditional: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value. PiperOrigin-RevId: 461678140
This commit is contained in:
parent
d98d5ddce5
commit
b0805a8a31
@ -200,7 +200,7 @@ def celu(x: Array, alpha: Array = 1.0) -> Array:
|
||||
x : input array
|
||||
alpha : array or scalar (default: 1.0)
|
||||
"""
|
||||
return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
|
||||
return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
|
||||
|
||||
@jax.jit
|
||||
def selu(x: Array) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user