mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Force float32 matmuls in examples_test.
This test started failing when we changed our CI to use L4 GPUs. Using highest precision resolves the problem.
This commit is contained in:
parent
c2d78abfa3
commit
24b47318bd
@ -23,7 +23,6 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
@ -55,12 +54,13 @@ class ExamplesTest(parameterized.TestCase):
|
||||
kernel = lambda x, y: jnp.dot(x, y)
|
||||
np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testKernelRegressionTrainAndPredict(self):
|
||||
n, d = 100, 20
|
||||
truth = self.rng.normal(size=d)
|
||||
xs = self.rng.normal(size=(n, d))
|
||||
ys = jnp.dot(xs, truth)
|
||||
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
|
||||
kernel = lambda x, y: jnp.dot(x, y)
|
||||
predict = kernel_lsq.train(kernel, xs, ys)
|
||||
np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user