Merge pull request #6019 from skye:examples_test

PiperOrigin-RevId: 362363477
This commit is contained in:
jax authors 2021-03-11 13:23:53 -08:00
commit 077793cd64
2 changed files with 4 additions and 3 deletions

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from jax import lax
from jax import test_util as jtu
from jax import random
import jax.numpy as jnp
@ -88,7 +89,7 @@ class ExamplesTest(jtu.JaxTestCase):
truth = rng.randn(d)
xs = rng.randn(n, d)
ys = jnp.dot(xs, truth)
kernel = lambda x, y: jnp.dot(x, y)
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
predict = kernel_lsq.train(kernel, xs, ys)
self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
check_dtypes=False)

View File

@ -19,7 +19,7 @@ import numpy.random as npr
import jax.numpy as jnp
from jax.experimental import optimizers
from jax import grad, jit, make_jaxpr, vmap
from jax import grad, jit, make_jaxpr, vmap, lax
def gram(kernel, xs):
@ -74,7 +74,7 @@ if __name__ == "__main__":
# linear kernel
linear_kernel = lambda x, y: jnp.dot(x, y)
linear_kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
truth = npr.randn(d)
xs = npr.randn(n, d)
ys = jnp.dot(xs, truth)