From 24b47318bd85456f65f4216872ba7d55469c176e Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Fri, 10 May 2024 17:47:43 +0000
Subject: [PATCH] Force float32 matmuls in examples_test.

This test started failing when we changed our CI to use L4 GPUs, which
(like other recent NVIDIA GPUs) perform float32 matmuls at reduced
(tf32) precision by default. Forcing the highest (float32) matmul
precision resolves the problem.
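
For reference, jax.default_matmul_precision works both as a decorator
(as used in this patch) and as a context manager; a minimal sketch,
with illustrative names:

    import jax
    import jax.numpy as jnp

    # Matmuls traced in this scope run at full float32 precision,
    # overriding the backend default (e.g. tf32 on recent NVIDIA GPUs).
    @jax.default_matmul_precision("float32")
    def kernel(x, y):
        return jnp.dot(x, y)

    # Equivalent scoped form:
    # with jax.default_matmul_precision("float32"):
    #     out = jnp.dot(x, y)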
---
 examples/examples_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/examples_test.py b/examples/examples_test.py
index fd705a4ef..007e8e658 100644
--- a/examples/examples_test.py
+++ b/examples/examples_test.py
@@ -23,7 +23,6 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
-from jax import lax
 from jax import random
 import jax.numpy as jnp
 from jax._src import test_util as jtu
@@ -55,12 +54,13 @@ class ExamplesTest(parameterized.TestCase):
     kernel = lambda x, y: jnp.dot(x, y)
     np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
 
+  @jax.default_matmul_precision("float32")
   def testKernelRegressionTrainAndPredict(self):
     n, d = 100, 20
     truth = self.rng.normal(size=d)
     xs = self.rng.normal(size=(n, d))
     ys = jnp.dot(xs, truth)
-    kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
+    kernel = lambda x, y: jnp.dot(x, y)
     predict = kernel_lsq.train(kernel, xs, ys)
    np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)
 