add jets for sines fns (#2892)

refactor

remove duplicate
This commit is contained in:
Jacob Kelly 2020-04-29 22:18:21 -04:00 committed by GitHub
parent c8d1700bd3
commit 1f7ebabfc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 0 deletions

View File

@ -354,6 +354,26 @@ def _div_taylor_rule(primals_in, series_in, **params):
return primal_out, series_out
jet_rules[lax.div_p] = _div_taylor_rule
def _sinusoidal_rule(sign, prims, primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
s, c = prims
s = [s(x)] + [None] * len(series)
c = [c(x)] + [None] * len(series)
for k in range(1, len(s)):
s[k] = fact(k-1) * sum(_scale(k, j) * u[j] * c[k-j] for j in range(1, k + 1))
c[k] = fact(k-1) * sum(_scale(k, j) * u[j] * s[k-j] for j in range(1, k + 1)) * sign
return (s[0], s[1:]), (c[0], c[1:])
def _get_ind(f, ind):
return lambda *args: f(*args)[ind]
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)
def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
x, y = primals_in
x_terms, y_terms = series_in

View File

@ -210,6 +210,14 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_expm1(self): self.unary_check(np.expm1)
@jtu.skip_on_devices("tpu")
def test_sin(self): self.unary_check(np.sin)
@jtu.skip_on_devices("tpu")
def test_cos(self): self.unary_check(np.cos)
@jtu.skip_on_devices("tpu")
def test_sinh(self): self.unary_check(np.sinh)
@jtu.skip_on_devices("tpu")
def test_cosh(self): self.unary_check(np.cosh)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)