implement jet rules by lowering to other primitives (#2816)

merge jet_test

add jet rules

use lax.square
This commit is contained in:
Jacob Kelly 2020-04-24 01:07:35 -04:00 committed by GitHub
parent 251834367f
commit fc4203c38a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 0 deletions

View File

@ -293,6 +293,41 @@ def _log_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.log_p] = _log_taylor
def _sqrt_taylor(primals_in, series_in):
return jet(lambda x: x ** 0.5, primals_in, series_in)
jet_rules[lax.sqrt_p] = _sqrt_taylor
def _rsqrt_taylor(primals_in, series_in):
return jet(lambda x: x ** -0.5, primals_in, series_in)
jet_rules[lax.rsqrt_p] = _rsqrt_taylor
def _asinh_taylor(primals_in, series_in):
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)), primals_in, series_in)
jet_rules[lax.asinh_p] = _asinh_taylor
def _acosh_taylor(primals_in, series_in):
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)), primals_in, series_in)
jet_rules[lax.acosh_p] = _acosh_taylor
def _atanh_taylor(primals_in, series_in):
return jet(lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)), primals_in, series_in)
jet_rules[lax.atanh_p] = _atanh_taylor
def _atan2_taylor(primals_in, series_in):
x, y = primals_in
primal_out = lax.atan2(x, y)
x, series = jet(lax.div, primals_in, series_in)
c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (x, ), (series, ))
c = [c0] + cs
u = [x] + series
v = [primal_out] + [None] * len(series)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
primal_out, *series_out = v
return primal_out, series_out
jet_rules[lax.atan2_p] = _atan2_taylor
def _log1p_taylor(primals_in, series_in):
x, = primals_in
series, = series_in

View File

@ -215,6 +215,16 @@ class JetTest(jtu.JaxTestCase):
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_sqrt(self): self.unary_check(np.sqrt, lims=[0, 5.])
@jtu.skip_on_devices("tpu")
def test_rsqrt(self): self.unary_check(lax.rsqrt, lims=[0, 5000.])
@jtu.skip_on_devices("tpu")
def test_asinh(self): self.unary_check(lax.asinh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_acosh(self): self.unary_check(lax.acosh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_atanh(self): self.unary_check(lax.atanh, lims=[-1, 1])
@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
@ -245,6 +255,8 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
@jtu.ignore_warning(message="overflow encountered in power")
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
@jtu.skip_on_devices("tpu")
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])
def test_process_call(self):
def f(x):