add expm1 and log1p

This commit is contained in:
Jacob Kelly 2020-04-07 17:55:07 -04:00
parent 1cf708ea77
commit 4d7b63c5ec
2 changed files with 27 additions and 0 deletions

View File

@ -228,6 +228,17 @@ def _exp_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.exp_p] = _exp_taylor
def _expm1_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
v = [lax.exp(x)] + [None] * len(series)
for k in range(1,len(v)):
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
primal_out, *series_out = v
return lax.expm1(x), series_out
jet_rules[lax.expm1_p] = _expm1_taylor
def _log_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
@ -240,6 +251,18 @@ def _log_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.log_p] = _log_taylor
def _log1p_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x + 1] + series
v = [lax.log(x + 1)] + [None] * len(series)
for k in range(1, len(v)):
conv = sum([_scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
primal_out, *series_out = v
return primal_out, series_out
jet_rules[lax.log1p_p] = _log1p_taylor
def _div_taylor_rule(primals_in, series_in, **params):
x, y = primals_in
x_terms, y_terms = series_in

View File

@ -150,6 +150,10 @@ class JetTest(jtu.JaxTestCase):
def test_abs(self): self.unary_check(np.abs)
@jtu.skip_on_devices("tpu")
def test_fft(self): self.unary_check(np.fft.fft)
@jtu.skip_on_devices("tpu")
def test_log1p(self): self.unary_check(np.log1p, lims=[0, 4.])
@jtu.skip_on_devices("tpu")
def test_expm1(self): self.unary_check(np.expm1)
@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])