mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
remove scipy dep, fix dtype issue
This commit is contained in:
parent
8d402d83da
commit
a00e3986d4
@ -17,7 +17,6 @@ from functools import partial
|
||||
from collections import Counter
|
||||
|
||||
import numpy as onp
|
||||
from scipy.special import factorial as fact
|
||||
|
||||
from jax import core
|
||||
from jax.util import unzip2, prod
|
||||
@ -140,6 +139,9 @@ def linear_prop(prim, primals_in, series_in, **params):
|
||||
|
||||
from jax.lax import lax
|
||||
|
||||
def fact(n):
|
||||
return lax.exp(lax.lgamma(n+1.))
|
||||
|
||||
deflinear(lax.neg_p)
|
||||
deflinear(lax.real_p)
|
||||
deflinear(lax.complex_p)
|
||||
|
@ -20,7 +20,6 @@ from unittest import SkipTest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as onp
|
||||
from scipy.special import factorial as fact
|
||||
|
||||
from jax import core
|
||||
from jax import test_util as jtu
|
||||
@ -29,7 +28,7 @@ import jax.numpy as np
|
||||
from jax import random
|
||||
from jax import jacobian, jit
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.jet import jet
|
||||
from jax.experimental.jet import jet, fact
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -53,12 +52,14 @@ def repeated(f, n):
|
||||
|
||||
class JetTest(jtu.JaxTestCase):
|
||||
|
||||
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5):
|
||||
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
|
||||
check_dtypes=True):
|
||||
y, terms = jet(fun, primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol, check_dtypes=True)
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=True)
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_exp(self):
|
||||
@ -99,19 +100,19 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
|
||||
x = rng.randn(*input_shape)
|
||||
x = rng.randn(*input_shape).astype("float32")
|
||||
primals = (W, b, x)
|
||||
|
||||
series_in1 = [rng.randn(*W.shape) for _ in range(order)]
|
||||
series_in2 = [rng.randn(*b.shape) for _ in range(order)]
|
||||
series_in3 = [rng.randn(*x.shape) for _ in range(order)]
|
||||
series_in1 = [rng.randn(*W.shape).astype("float32") for _ in range(order)]
|
||||
series_in2 = [rng.randn(*b.shape).astype("float32") for _ in range(order)]
|
||||
series_in3 = [rng.randn(*x.shape).astype("float32") for _ in range(order)]
|
||||
|
||||
series_in = (series_in1, series_in2, series_in3)
|
||||
|
||||
def f(W, b, x):
|
||||
return apply_fun((W, b), x)
|
||||
|
||||
self.check_jet(f, primals, series_in)
|
||||
self.check_jet(f, primals, series_in, check_dtypes=False)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user