jax.scipy.special: fix gradient for xlogy & xlog1py

Jake VanderPlas 2023-04-18 15:56:32 -07:00
parent 7cf86d1577
commit dd023e266e
2 changed files with 35 additions and 5 deletions
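
For context, the bug being fixed (https://github.com/google/jax/issues/15598, referenced in the tests below): xlogy(x, y) computes x * log(y) with the convention that xlogy(0, y) == 0, and its true partial derivatives are d/dx = log(y) and d/dy = x / y. The jnp.where masking that enforces the x == 0 convention also zeroed the x-gradient there, so jax.grad returned 0 instead of log(y). A minimal illustration (not part of the commit; it assumes a JAX build that includes this fix):

import jax
from jax.scipy.special import xlogy, xlog1py

# With the custom JVPs added below, the x-gradient at x == 0 is the
# analytic value log(y) (resp. log1p(y)) rather than a masked zero.
print(jax.grad(xlogy, argnums=0)(0.0, 3.0))    # log(3)   ~= 1.0986
print(jax.grad(xlog1py, argnums=0)(0.0, 3.0))  # log1p(3) ~= 1.3863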

@@ -114,23 +114,40 @@ def expit(x: ArrayLike) -> Array:
 logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
 
 
+@custom_derivatives.custom_jvp
 @_wraps(osp_special.xlogy, module='scipy.special')
 def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
   # Note: xlogy(0, 0) should return 0 according to the function documentation.
   x, y = promote_args_inexact("xlogy", x, y)
   x_ok = x != 0.
   safe_x = jnp.where(x_ok, x, 1.)
   safe_y = jnp.where(x_ok, y, 1.)
   return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x))
 
+def _xlogy_jvp(primals, tangents):
+  (x, y) = primals
+  (x_dot, y_dot) = tangents
+  result = xlogy(x, y)
+  return result, (x_dot * lax.log(y) + y_dot * x / y).astype(result.dtype)
+xlogy.defjvp(_xlogy_jvp)
+
+
+@custom_derivatives.custom_jvp
 @_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
 def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
   # Note: xlog1py(0, -1) should return 0 according to the function documentation.
   x, y = promote_args_inexact("xlog1py", x, y)
   x_ok = x != 0.
   safe_x = jnp.where(x_ok, x, 1.)
   safe_y = jnp.where(x_ok, y, 1.)
   return jnp.where(x_ok, lax.mul(safe_x, lax.log1p(safe_y)), jnp.zeros_like(x))
 
+def _xlog1py_jvp(primals, tangents):
+  (x, y) = primals
+  (x_dot, y_dot) = tangents
+  result = xlog1py(x, y)
+  return result, (x_dot * lax.log1p(y) + y_dot * x / (1 + y)).astype(result.dtype)
+xlog1py.defjvp(_xlog1py_jvp)
 
 
 @_wraps(osp_special.entr, module='scipy.special')
 def entr(x: ArrayLike) -> Array:
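
The JVP rules above apply the product rule directly (tangent = x_dot * log(y) + y_dot * x / y for xlogy, and the log1p analogue for xlog1py) instead of differentiating through the jnp.where masking, so the gradient stays correct at x == 0. A quick way to exercise a rule by hand (a sketch, assuming the public jax.jvp API):

import jax
from jax.scipy.special import xlogy

# Push the tangent (1.0, 0.0) through xlogy at (x, y) = (0.0, 3.0):
# the output tangent is d/dx xlogy(0.0, 3.0) = log(3.0), as computed
# by _xlogy_jvp above.
primal_out, tangent_out = jax.jvp(xlogy, (0.0, 3.0), (1.0, 0.0))
print(primal_out, tangent_out)  # 0.0  ~1.0986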

@@ -13,7 +13,6 @@
 # limitations under the License.
 
-import functools
 from functools import partial
 import itertools
 import unittest
@@ -220,15 +219,29 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)
 
   def testGradOfXlogyAtZero(self):
-    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
-    self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False)
+    # https://github.com/google/jax/issues/15598
+    x0, y0 = 0.0, 3.0
+    d_xlogy_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0)
+    self.assertAllClose(d_xlogy_dx, lax.log(y0))
+    d_xlogy_dy = jax.grad(lsp_special.xlogy, argnums=1)(x0, y0)
+    self.assertAllClose(d_xlogy_dy, 0.0)
+    jtu.check_grads(lsp_special.xlogy, (x0, y0), order=2)
 
   def testXlog1pyShouldReturnZero(self):
     self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)
 
   def testGradOfXlog1pyAtZero(self):
-    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
-    self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
+    # https://github.com/google/jax/issues/15598
+    x0, y0 = 0.0, 3.0
+    d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0)
+    self.assertAllClose(d_xlog1py_dx, lax.log1p(y0))
+    d_xlog1py_dy = jax.grad(lsp_special.xlog1py, argnums=1)(x0, y0)
+    self.assertAllClose(d_xlog1py_dy, 0.0)
+    jtu.check_grads(lsp_special.xlog1py, (x0, y0), order=2)
 
   @jtu.sample_product(
     [dict(order=order, z=z, n_iter=n_iter)