Fix parenthesis in "Gradients contain NaN where using where"

This commit is contained in:
Marco Selvi 2024-05-31 12:03:39 +01:00 committed by GitHub
parent 8deed95c7f
commit 7a8bcf6ee5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -686,7 +686,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.::
def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.), y)
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
Additional reading:
@ -849,4 +849,4 @@ see the page on `JAX GPU memory allocation`_.
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html