add tanh rule (#2653)

change expit taylor rule

add manual expit check, check stability of expit and tanh
This commit is contained in:
Jacob Kelly 2020-04-22 20:49:10 -04:00 committed by GitHub
parent 8fe3c59ced
commit 59bdb1fb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 0 deletions

View File

@ -17,6 +17,7 @@ from functools import partial
import numpy as onp
import jax
from jax import core
from jax.util import unzip2
from jax.tree_util import (register_pytree_node, tree_structure,
@ -217,6 +218,9 @@ def fact(n):
def _scale(k, j):
return 1. / (fact(k - j) * fact(j - 1))
def _scale2(k, j):
return 1. / (fact(k - j) * fact(j))
def _exp_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
@ -253,6 +257,28 @@ def _pow_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.pow_p] = _pow_taylor
def _expit_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
v = [jax.scipy.special.expit(x)] + [None] * len(series)
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum([_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)])
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j)* v[j] * v[k-j] for j in range(1, k+1)])
primal_out, *series_out = v
return primal_out, series_out
def _tanh_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [2*x] + [2 * series_ for series_ in series]
primals_in, *series_in = u
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
series_out = [2 * series_ for series_ in series_out]
return 2 * primal_out - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor
def _log_taylor(primals_in, series_in):
x, = primals_in

View File

@ -20,6 +20,7 @@ import numpy as onp
from jax import test_util as jtu
import jax.numpy as np
import jax.scipy.special
from jax import random
from jax import jacfwd, jit
from jax.experimental import stax
@ -153,6 +154,27 @@ class JetTest(jtu.JaxTestCase):
else:
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
def expit_check(self, lims=[-2, 2], order=3):
dims = 2, 3
rng = onp.random.RandomState(0)
primal_in = transform(lims, rng.rand(*dims))
terms_in = [rng.randn(*dims) for _ in range(order)]
primals = (primal_in, )
series = (terms_in, )
y, terms = jax.experimental.jet._expit_taylor(primals, series)
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)
atol = 1e-4
rtol = 1e-4
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=True)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=True)
@jtu.skip_on_devices("tpu")
def test_exp(self): self.unary_check(np.exp)
@jtu.skip_on_devices("tpu")
@ -187,6 +209,12 @@ class JetTest(jtu.JaxTestCase):
def test_log1p(self): self.unary_check(np.log1p, lims=[0, 4.])
@jtu.skip_on_devices("tpu")
def test_expm1(self): self.unary_check(np.expm1)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])