mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
implement jet rules by lowering to other primitives (#2816)
merge jet_test add jet rules use lax.square
This commit is contained in:
parent
251834367f
commit
fc4203c38a
@ -293,6 +293,41 @@ def _log_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.log_p] = _log_taylor
|
||||
|
||||
def _sqrt_taylor(primals_in, series_in):
|
||||
return jet(lambda x: x ** 0.5, primals_in, series_in)
|
||||
jet_rules[lax.sqrt_p] = _sqrt_taylor
|
||||
|
||||
def _rsqrt_taylor(primals_in, series_in):
|
||||
return jet(lambda x: x ** -0.5, primals_in, series_in)
|
||||
jet_rules[lax.rsqrt_p] = _rsqrt_taylor
|
||||
|
||||
def _asinh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)), primals_in, series_in)
|
||||
jet_rules[lax.asinh_p] = _asinh_taylor
|
||||
|
||||
def _acosh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)), primals_in, series_in)
|
||||
jet_rules[lax.acosh_p] = _acosh_taylor
|
||||
|
||||
def _atanh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)), primals_in, series_in)
|
||||
jet_rules[lax.atanh_p] = _atanh_taylor
|
||||
|
||||
def _atan2_taylor(primals_in, series_in):
|
||||
x, y = primals_in
|
||||
primal_out = lax.atan2(x, y)
|
||||
|
||||
x, series = jet(lax.div, primals_in, series_in)
|
||||
c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (x, ), (series, ))
|
||||
c = [c0] + cs
|
||||
u = [x] + series
|
||||
v = [primal_out] + [None] * len(series)
|
||||
for k in range(1, len(v)):
|
||||
v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.atan2_p] = _atan2_taylor
|
||||
|
||||
def _log1p_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
|
@ -215,6 +215,16 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sqrt(self): self.unary_check(np.sqrt, lims=[0, 5.])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_rsqrt(self): self.unary_check(lax.rsqrt, lims=[0, 5000.])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_asinh(self): self.unary_check(lax.asinh, lims=[-100, 100])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_acosh(self): self.unary_check(lax.acosh, lims=[-100, 100])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_atanh(self): self.unary_check(lax.atanh, lims=[-1, 1])
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
@ -245,6 +255,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jtu.ignore_warning(message="overflow encountered in power")
|
||||
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])
|
||||
|
||||
def test_process_call(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user