From dd023e266e6616494cdd590959103ddc646109c4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 18 Apr 2023 15:56:32 -0700 Subject: [PATCH] jax.scipy.special: fix gradient for xlogy & xlog1py --- jax/_src/scipy/special.py | 17 +++++++++++++++++ tests/lax_scipy_test.py | 23 ++++++++++++++++++----- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index f8f6d3f4b..1e732c021 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -114,23 +114,40 @@ def expit(x: ArrayLike) -> Array: logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp) +@custom_derivatives.custom_jvp @_wraps(osp_special.xlogy, module='scipy.special') def xlogy(x: ArrayLike, y: ArrayLike) -> Array: + # Note: xlogy(0, 0) should return 0 according to the function documentation. x, y = promote_args_inexact("xlogy", x, y) x_ok = x != 0. safe_x = jnp.where(x_ok, x, 1.) safe_y = jnp.where(x_ok, y, 1.) return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x)) +def _xlogy_jvp(primals, tangents): + (x, y) = primals + (x_dot, y_dot) = tangents + result = xlogy(x, y) + return result, (x_dot * lax.log(y) + y_dot * x / y).astype(result.dtype) +xlogy.defjvp(_xlogy_jvp) + +@custom_derivatives.custom_jvp @_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False) def xlog1py(x: ArrayLike, y: ArrayLike) -> Array: + # Note: xlog1py(0, -1) should return 0 according to the function documentation. x, y = promote_args_inexact("xlog1py", x, y) x_ok = x != 0. safe_x = jnp.where(x_ok, x, 1.) safe_y = jnp.where(x_ok, y, 1.) 
return jnp.where(x_ok, lax.mul(safe_x, lax.log1p(safe_y)), jnp.zeros_like(x)) +def _xlog1py_jvp(primals, tangents): + (x, y) = primals + (x_dot, y_dot) = tangents + result = xlog1py(x, y) + return result, (x_dot * lax.log1p(y) + y_dot * x / (1 + y)).astype(result.dtype) +xlog1py.defjvp(_xlog1py_jvp) @_wraps(osp_special.entr, module='scipy.special') def entr(x: ArrayLike) -> Array: diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index ca431eb9b..84d5da5ff 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -13,7 +13,6 @@ # limitations under the License. -import functools from functools import partial import itertools import unittest @@ -220,15 +219,29 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): - partial_xlogy = functools.partial(lsp_special.xlogy, 0.) - self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False) + # https://github.com/google/jax/issues/15598 + x0, y0 = 0.0, 3.0 + d_xlogy_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0) + self.assertAllClose(d_xlogy_dx, lax.log(y0)) + + d_xlogy_dy = jax.grad(lsp_special.xlogy, argnums=1)(x0, y0) + self.assertAllClose(d_xlogy_dy, 0.0) + + jtu.check_grads(lsp_special.xlogy, (x0, y0), order=2) def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): - partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.) 
- self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False) + # https://github.com/google/jax/issues/15598 + x0, y0 = 0.0, 3.0 + d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0) + self.assertAllClose(d_xlog1py_dx, lax.log1p(y0)) + + d_xlog1py_dy = jax.grad(lsp_special.xlog1py, argnums=1)(x0, y0) + self.assertAllClose(d_xlog1py_dy, 0.0) + + jtu.check_grads(lsp_special.xlog1py, (x0, y0), order=2) @jtu.sample_product( [dict(order=order, z=z, n_iter=n_iter)