remove scipy dep, fix dtype issue

This commit is contained in:
Matthew Johnson 2020-03-15 12:00:44 -07:00
parent 8d402d83da
commit a00e3986d4
2 changed files with 14 additions and 11 deletions

View File

@ -17,7 +17,6 @@ from functools import partial
from collections import Counter
import numpy as onp
from scipy.special import factorial as fact
from jax import core
from jax.util import unzip2, prod
@ -140,6 +139,9 @@ def linear_prop(prim, primals_in, series_in, **params):
from jax.lax import lax
def fact(n):
return lax.exp(lax.lgamma(n+1.))
deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)

View File

@ -20,7 +20,6 @@ from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
from scipy.special import factorial as fact
from jax import core
from jax import test_util as jtu
@ -29,7 +28,7 @@ import jax.numpy as np
from jax import random
from jax import jacobian, jit
from jax.experimental import stax
from jax.experimental.jet import jet
from jax.experimental.jet import jet, fact
from jax.config import config
config.parse_flags_with_absl()
@ -53,12 +52,14 @@ def repeated(f, n):
class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol, check_dtypes=True)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=True)
check_dtypes=check_dtypes)
@jtu.skip_on_devices("tpu")
def test_exp(self):
@ -99,19 +100,19 @@ class JetTest(jtu.JaxTestCase):
rng = onp.random.RandomState(0)
x = rng.randn(*input_shape)
x = rng.randn(*input_shape).astype("float32")
primals = (W, b, x)
series_in1 = [rng.randn(*W.shape) for _ in range(order)]
series_in2 = [rng.randn(*b.shape) for _ in range(order)]
series_in3 = [rng.randn(*x.shape) for _ in range(order)]
series_in1 = [rng.randn(*W.shape).astype("float32") for _ in range(order)]
series_in2 = [rng.randn(*b.shape).astype("float32") for _ in range(order)]
series_in3 = [rng.randn(*x.shape).astype("float32") for _ in range(order)]
series_in = (series_in1, series_in2, series_in3)
def f(W, b, x):
return apply_fun((W, b), x)
self.check_jet(f, primals, series_in)
self.check_jet(f, primals, series_in, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_div(self):