Force float32 matmuls in examples_test.

This test started failing when we changed our CI to use L4 GPUs. Using
highest precision resolves the problem.
This commit is contained in:
Peter Hawkins 2024-05-10 17:47:43 +00:00
parent c2d78abfa3
commit 24b47318bd

View File

@ -23,7 +23,6 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax._src import test_util as jtu
@ -55,12 +54,13 @@ class ExamplesTest(parameterized.TestCase):
kernel = lambda x, y: jnp.dot(x, y)
np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
@jax.default_matmul_precision("float32")
def testKernelRegressionTrainAndPredict(self):
n, d = 100, 20
truth = self.rng.normal(size=d)
xs = self.rng.normal(size=(n, d))
ys = jnp.dot(xs, truth)
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
kernel = lambda x, y: jnp.dot(x, y)
predict = kernel_lsq.train(kernel, xs, ys)
np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)