[x64] make x64_context_test more robust

This commit is contained in:
Jake VanderPlas 2021-11-22 14:54:30 -08:00
parent 6cc7d67484
commit c5c78b5f6d

View File

@ -20,6 +20,7 @@ import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax._src import api
@ -49,7 +50,7 @@ class X64ContextTests(jtu.JaxTestCase):
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in ["python", "cpp", None]))
def test_make_array(self, jit):
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
func = _maybe_jit(jit, lambda: jnp.array(np.float64(0)))
dtype_start = func().dtype
with enable_x64():
self.assertEqual(func().dtype, "float64")
@ -67,7 +68,7 @@ class X64ContextTests(jtu.JaxTestCase):
# The fact we defined a jitted function with a block with a different value
# of `config.enable_x64` has no impact on the output.
with enable_or_disable():
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
func = _maybe_jit(jit, lambda: jnp.array(np.float64(0)))
func()
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
@ -78,28 +79,26 @@ class X64ContextTests(jtu.JaxTestCase):
with disable_x64():
self.assertEqual(func().dtype, "float32")
@unittest.skipIf(jtu.device_under_test() != "cpu", "Test presumes CPU precision")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in ["python", "cpp", None]))
def test_near_singular_inverse(self, jit):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit inverse not available on TPU")
rng = jtu.rand_default(self.rng())
@partial(_maybe_jit, jit, static_argnums=1)
def near_singular_inverse(key, N, eps):
X = random.uniform(key, (N, N))
def near_singular_inverse(N=5, eps=1E-40):
X = rng((N, N), dtype='float64')
X = jnp.asarray(X)
X = X.at[-1].mul(eps)
return jnp.linalg.inv(X)
key = random.PRNGKey(1701)
eps = 1E-40
N = 5
with enable_x64():
result_64 = near_singular_inverse(key, N, eps)
result_64 = near_singular_inverse()
self.assertTrue(jnp.all(jnp.isfinite(result_64)))
with disable_x64():
result_32 = near_singular_inverse(key, N, eps)
result_32 = near_singular_inverse()
self.assertTrue(jnp.all(~jnp.isfinite(result_32)))
@parameterized.named_parameters(jtu.cases_from_list(
@ -120,12 +119,12 @@ class X64ContextTests(jtu.JaxTestCase):
def func_x32():
with disable_x64():
time.sleep(0.1)
return jnp.arange(10).dtype
return jnp.array(np.int64(0)).dtype
def func_x64():
with enable_x64():
time.sleep(0.1)
return jnp.arange(10).dtype
return jnp.array(np.int64(0)).dtype
with concurrent.futures.ThreadPoolExecutor() as executor:
x32 = executor.submit(func_x32)