mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #6019 from skye:examples_test
PiperOrigin-RevId: 362363477
This commit is contained in:
commit
077793cd64
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
@ -88,7 +89,7 @@ class ExamplesTest(jtu.JaxTestCase):
|
||||
truth = rng.randn(d)
|
||||
xs = rng.randn(n, d)
|
||||
ys = jnp.dot(xs, truth)
|
||||
kernel = lambda x, y: jnp.dot(x, y)
|
||||
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
|
||||
predict = kernel_lsq.train(kernel, xs, ys)
|
||||
self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
|
||||
check_dtypes=False)
|
||||
|
@ -19,7 +19,7 @@ import numpy.random as npr
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import optimizers
|
||||
from jax import grad, jit, make_jaxpr, vmap
|
||||
from jax import grad, jit, make_jaxpr, vmap, lax
|
||||
|
||||
|
||||
def gram(kernel, xs):
|
||||
@ -74,7 +74,7 @@ if __name__ == "__main__":
|
||||
|
||||
# linear kernel
|
||||
|
||||
linear_kernel = lambda x, y: jnp.dot(x, y)
|
||||
linear_kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
|
||||
truth = npr.randn(d)
|
||||
xs = npr.randn(n, d)
|
||||
ys = jnp.dot(xs, truth)
|
||||
|
Loading…
x
Reference in New Issue
Block a user