mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
c8d1700bd3
commit
1f7ebabfc8
@ -354,6 +354,26 @@ def _div_taylor_rule(primals_in, series_in, **params):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.div_p] = _div_taylor_rule
|
||||
|
||||
def _sinusoidal_rule(sign, prims, primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
s, c = prims
|
||||
s = [s(x)] + [None] * len(series)
|
||||
c = [c(x)] + [None] * len(series)
|
||||
for k in range(1, len(s)):
|
||||
s[k] = fact(k-1) * sum(_scale(k, j) * u[j] * c[k-j] for j in range(1, k + 1))
|
||||
c[k] = fact(k-1) * sum(_scale(k, j) * u[j] * s[k-j] for j in range(1, k + 1)) * sign
|
||||
return (s[0], s[1:]), (c[0], c[1:])
|
||||
|
||||
def _get_ind(f, ind):
|
||||
return lambda *args: f(*args)[ind]
|
||||
|
||||
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
|
||||
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
|
||||
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
|
||||
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)
|
||||
|
||||
def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
|
||||
x, y = primals_in
|
||||
x_terms, y_terms = series_in
|
||||
|
@ -210,6 +210,14 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expm1(self): self.unary_check(np.expm1)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sin(self): self.unary_check(np.sin)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_cos(self): self.unary_check(np.cos)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sinh(self): self.unary_check(np.sinh)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_cosh(self): self.unary_check(np.cosh)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
|
||||
|
Loading…
x
Reference in New Issue
Block a user