mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add tanh rule (#2653)
change expit taylor rule add manual expit check, check stability of expit and tanh
This commit is contained in:
parent
8fe3c59ced
commit
59bdb1fb3d
@ -17,6 +17,7 @@ from functools import partial
|
||||
|
||||
import numpy as onp
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax.util import unzip2
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
@ -217,6 +218,9 @@ def fact(n):
|
||||
def _scale(k, j):
|
||||
return 1. / (fact(k - j) * fact(j - 1))
|
||||
|
||||
def _scale2(k, j):
|
||||
return 1. / (fact(k - j) * fact(j))
|
||||
|
||||
def _exp_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
@ -253,6 +257,28 @@ def _pow_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.pow_p] = _pow_taylor
|
||||
|
||||
def _expit_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
v = [jax.scipy.special.expit(x)] + [None] * len(series)
|
||||
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
|
||||
for k in range(1, len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)])
|
||||
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j)* v[j] * v[k-j] for j in range(1, k+1)])
|
||||
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
|
||||
def _tanh_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [2*x] + [2 * series_ for series_ in series]
|
||||
primals_in, *series_in = u
|
||||
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
|
||||
series_out = [2 * series_ for series_ in series_out]
|
||||
return 2 * primal_out - 1, series_out
|
||||
jet_rules[lax.tanh_p] = _tanh_taylor
|
||||
|
||||
def _log_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
|
@ -20,6 +20,7 @@ import numpy as onp
|
||||
|
||||
from jax import test_util as jtu
|
||||
import jax.numpy as np
|
||||
import jax.scipy.special
|
||||
from jax import random
|
||||
from jax import jacfwd, jit
|
||||
from jax.experimental import stax
|
||||
@ -153,6 +154,27 @@ class JetTest(jtu.JaxTestCase):
|
||||
else:
|
||||
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def expit_check(self, lims=[-2, 2], order=3):
|
||||
dims = 2, 3
|
||||
rng = onp.random.RandomState(0)
|
||||
primal_in = transform(lims, rng.rand(*dims))
|
||||
terms_in = [rng.randn(*dims) for _ in range(order)]
|
||||
|
||||
primals = (primal_in, )
|
||||
series = (terms_in, )
|
||||
|
||||
y, terms = jax.experimental.jet._expit_taylor(primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)
|
||||
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=True)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=True)
|
||||
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_exp(self): self.unary_check(np.exp)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -187,6 +209,12 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_log1p(self): self.unary_check(np.log1p, lims=[0, 4.])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expm1(self): self.unary_check(np.expm1)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user