mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add expm1 and log1p
This commit is contained in:
parent
1cf708ea77
commit
4d7b63c5ec
@ -228,6 +228,17 @@ def _exp_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.exp_p] = _exp_taylor
|
||||
|
||||
def _expm1_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
v = [lax.exp(x)] + [None] * len(series)
|
||||
for k in range(1,len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
|
||||
primal_out, *series_out = v
|
||||
return lax.expm1(x), series_out
|
||||
jet_rules[lax.expm1_p] = _expm1_taylor
|
||||
|
||||
def _log_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
@ -240,6 +251,18 @@ def _log_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.log_p] = _log_taylor
|
||||
|
||||
def _log1p_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x + 1] + series
|
||||
v = [lax.log(x + 1)] + [None] * len(series)
|
||||
for k in range(1, len(v)):
|
||||
conv = sum([_scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
|
||||
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.log1p_p] = _log1p_taylor
|
||||
|
||||
def _div_taylor_rule(primals_in, series_in, **params):
|
||||
x, y = primals_in
|
||||
x_terms, y_terms = series_in
|
||||
|
@ -150,6 +150,10 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_abs(self): self.unary_check(np.abs)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_fft(self): self.unary_check(np.fft.fft)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_log1p(self): self.unary_check(np.log1p, lims=[0, 4.])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expm1(self): self.unary_check(np.expm1)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user