diff --git a/examples/examples_test.py b/examples/examples_test.py index 1162ef058..2b2b56fcf 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np +from jax import lax from jax import test_util as jtu from jax import random import jax.numpy as jnp @@ -88,7 +89,7 @@ class ExamplesTest(jtu.JaxTestCase): truth = rng.randn(d) xs = rng.randn(n, d) ys = jnp.dot(xs, truth) - kernel = lambda x, y: jnp.dot(x, y) + kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH) predict = kernel_lsq.train(kernel, xs, ys) self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3, check_dtypes=False) diff --git a/examples/kernel_lsq.py b/examples/kernel_lsq.py index 570828b40..5211f4da0 100644 --- a/examples/kernel_lsq.py +++ b/examples/kernel_lsq.py @@ -19,7 +19,7 @@ import numpy.random as npr import jax.numpy as jnp from jax.experimental import optimizers -from jax import grad, jit, make_jaxpr, vmap +from jax import grad, jit, make_jaxpr, vmap, lax def gram(kernel, xs): @@ -74,7 +74,7 @@ if __name__ == "__main__": # linear kernel - linear_kernel = lambda x, y: jnp.dot(x, y) + linear_kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH) truth = npr.randn(d) xs = npr.randn(n, d) ys = jnp.dot(xs, truth)