mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[x64] make x64_context_test more robust
This commit is contained in:
parent
6cc7d67484
commit
c5c78b5f6d
@ -20,6 +20,7 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
@ -49,7 +50,7 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in ["python", "cpp", None]))
|
||||
def test_make_array(self, jit):
|
||||
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
|
||||
func = _maybe_jit(jit, lambda: jnp.array(np.float64(0)))
|
||||
dtype_start = func().dtype
|
||||
with enable_x64():
|
||||
self.assertEqual(func().dtype, "float64")
|
||||
@ -67,7 +68,7 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
# The fact we defined a jitted function with a block with a different value
|
||||
# of `config.enable_x64` has no impact on the output.
|
||||
with enable_or_disable():
|
||||
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
|
||||
func = _maybe_jit(jit, lambda: jnp.array(np.float64(0)))
|
||||
func()
|
||||
|
||||
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
|
||||
@ -78,28 +79,26 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
with disable_x64():
|
||||
self.assertEqual(func().dtype, "float32")
|
||||
|
||||
@unittest.skipIf(jtu.device_under_test() != "cpu", "Test presumes CPU precision")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in ["python", "cpp", None]))
|
||||
def test_near_singular_inverse(self, jit):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.skipTest("64-bit inverse not available on TPU")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
||||
@partial(_maybe_jit, jit, static_argnums=1)
|
||||
def near_singular_inverse(key, N, eps):
|
||||
X = random.uniform(key, (N, N))
|
||||
def near_singular_inverse(N=5, eps=1E-40):
|
||||
X = rng((N, N), dtype='float64')
|
||||
X = jnp.asarray(X)
|
||||
X = X.at[-1].mul(eps)
|
||||
return jnp.linalg.inv(X)
|
||||
|
||||
key = random.PRNGKey(1701)
|
||||
eps = 1E-40
|
||||
N = 5
|
||||
|
||||
with enable_x64():
|
||||
result_64 = near_singular_inverse(key, N, eps)
|
||||
result_64 = near_singular_inverse()
|
||||
self.assertTrue(jnp.all(jnp.isfinite(result_64)))
|
||||
|
||||
with disable_x64():
|
||||
result_32 = near_singular_inverse(key, N, eps)
|
||||
result_32 = near_singular_inverse()
|
||||
self.assertTrue(jnp.all(~jnp.isfinite(result_32)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -120,12 +119,12 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
def func_x32():
|
||||
with disable_x64():
|
||||
time.sleep(0.1)
|
||||
return jnp.arange(10).dtype
|
||||
return jnp.array(np.int64(0)).dtype
|
||||
|
||||
def func_x64():
|
||||
with enable_x64():
|
||||
time.sleep(0.1)
|
||||
return jnp.arange(10).dtype
|
||||
return jnp.array(np.int64(0)).dtype
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
x32 = executor.submit(func_x32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user