mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix parenthesis in "Gradients contain NaN where using where"
This commit is contained in:
parent
8deed95c7f
commit
7a8bcf6ee5
@ -686,7 +686,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.::
|
||||
|
||||
def my_log_or_y(x, y):
|
||||
"""Return log(x) if x > 0 or y"""
|
||||
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.), y)
|
||||
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
|
||||
|
||||
|
||||
Additional reading:
|
||||
@ -849,4 +849,4 @@ see the page on `JAX GPU memory allocation`_.
|
||||
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
|
||||
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
|
||||
.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
|
||||
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html
|
||||
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html
|
||||
|
Loading…
x
Reference in New Issue
Block a user