mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jax.scipy.special: fix gradient for xlogy & xlog1py
This commit is contained in:
parent
7cf86d1577
commit
dd023e266e
@ -114,23 +114,40 @@ def expit(x: ArrayLike) -> Array:
|
||||
logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.xlogy, module='scipy.special')
|
||||
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
# Note: xlogy(0, 0) should return 0 according to the function documentation.
|
||||
x, y = promote_args_inexact("xlogy", x, y)
|
||||
x_ok = x != 0.
|
||||
safe_x = jnp.where(x_ok, x, 1.)
|
||||
safe_y = jnp.where(x_ok, y, 1.)
|
||||
return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x))
|
||||
|
||||
def _xlogy_jvp(primals, tangents):
|
||||
(x, y) = primals
|
||||
(x_dot, y_dot) = tangents
|
||||
result = xlogy(x, y)
|
||||
return result, (x_dot * lax.log(y) + y_dot * x / y).astype(result.dtype)
|
||||
xlogy.defjvp(_xlogy_jvp)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
|
||||
def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
# Note: xlog1py(0, -1) should return 0 according to the function documentation.
|
||||
x, y = promote_args_inexact("xlog1py", x, y)
|
||||
x_ok = x != 0.
|
||||
safe_x = jnp.where(x_ok, x, 1.)
|
||||
safe_y = jnp.where(x_ok, y, 1.)
|
||||
return jnp.where(x_ok, lax.mul(safe_x, lax.log1p(safe_y)), jnp.zeros_like(x))
|
||||
|
||||
def _xlog1py_jvp(primals, tangents):
|
||||
(x, y) = primals
|
||||
(x_dot, y_dot) = tangents
|
||||
result = xlog1py(x, y)
|
||||
return result, (x_dot * lax.log1p(y) + y_dot * x / (1 + y)).astype(result.dtype)
|
||||
xlog1py.defjvp(_xlog1py_jvp)
|
||||
|
||||
@_wraps(osp_special.entr, module='scipy.special')
|
||||
def entr(x: ArrayLike) -> Array:
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
import unittest
|
||||
@ -220,15 +219,29 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)
|
||||
|
||||
def testGradOfXlogyAtZero(self):
|
||||
partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
|
||||
self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False)
|
||||
# https://github.com/google/jax/issues/15598
|
||||
x0, y0 = 0.0, 3.0
|
||||
d_xlog1py_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0)
|
||||
self.assertAllClose(d_xlog1py_dx, lax.log(y0))
|
||||
|
||||
d_xlog1py_dy = jax.grad(lsp_special.xlogy, argnums=1)(x0, y0)
|
||||
self.assertAllClose(d_xlog1py_dy, 0.0)
|
||||
|
||||
jtu.check_grads(lsp_special.xlogy, (x0, y0), order=2)
|
||||
|
||||
def testXlog1pyShouldReturnZero(self):
|
||||
self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)
|
||||
|
||||
def testGradOfXlog1pyAtZero(self):
|
||||
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
|
||||
self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
|
||||
# https://github.com/google/jax/issues/15598
|
||||
x0, y0 = 0.0, 3.0
|
||||
d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0)
|
||||
self.assertAllClose(d_xlog1py_dx, lax.log1p(y0))
|
||||
|
||||
d_xlog1py_dy = jax.grad(lsp_special.xlog1py, argnums=1)(x0, y0)
|
||||
self.assertAllClose(d_xlog1py_dy, 0.0)
|
||||
|
||||
jtu.check_grads(lsp_special.xlog1py, (x0, y0), order=2)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(order=order, z=z, n_iter=n_iter)
|
||||
|
Loading…
x
Reference in New Issue
Block a user