disable implicit rank promotion for api_test

This commit is contained in:
Jake VanderPlas 2022-01-24 11:46:25 -08:00
parent 16c809ce7f
commit 3197aacbfc

View File

@ -69,6 +69,7 @@ python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CPPJitTest(jtu.BufferDonationTestCase):
"""Shared tests between the Python and the C++ jax,jit implementations.
@ -829,6 +830,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f({E.A: 1.0, E.B: 2.0})
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PythonJitTest(CPPJitTest):
@property
@ -836,6 +838,7 @@ class PythonJitTest(CPPJitTest):
return api._python_jit
@jtu.with_config(jax_numpy_rank_promotion="raise")
class APITest(jtu.JaxTestCase):
def test_grad_item(self):
@ -1759,7 +1762,7 @@ class APITest(jtu.JaxTestCase):
self.dtype = np.dtype(dtype)
A = MyArgArray((3, 4), jnp.float32)
b = MyArgArray((5,), jnp.float32)
b = MyArgArray((1, 5), jnp.float32)
x = MyArgArray((4, 5), jnp.float32)
out_shape = api.eval_shape(fun, A, b, x)
@ -3193,50 +3196,6 @@ class APITest(jtu.JaxTestCase):
finally:
FLAGS.jax_default_matmul_precision = precision
def test_rank_promotion_forces_retrace(self):
num_traces = 0
def g(x):
nonlocal num_traces
num_traces += 1
return x + x
def f_cond(x):
return lax.cond(True, g, g, x)
@jax.jit
def f_jit(x):
nonlocal num_traces
num_traces += 1
return x + x
for f in [f_jit, f_cond]:
allow_promotion = config.jax_numpy_rank_promotion
try:
num_traces = 0
@jax.jit
def f(x):
nonlocal num_traces
num_traces += 1
return x + x
x = jnp.zeros((2, 2))
f(x)
self.assertEqual(num_traces, 1)
f(x)
self.assertEqual(num_traces, 1)
with jax.numpy_rank_promotion("warn"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_numpy_rank_promotion = "raise"
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
f(x)
self.assertEqual(num_traces, nt + 1)
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_numpy_rank_promotion = allow_promotion
def test_backward_pass_ref_dropping(self):
refs = []
@ -3333,7 +3292,7 @@ class APITest(jtu.JaxTestCase):
def sigmoid(x):
return 1. / (1. + jnp.exp(-x))
x = jnp.ones((50,))
x = jnp.ones((1, 50))
A = jnp.array(npr.randn(50, 50))
@jax.jit
@ -3370,6 +3329,53 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count[0], 0)
class RankPromotionTest(jtu.JaxTestCase):
def test_rank_promotion_forces_retrace(self):
num_traces = 0
def g(x):
nonlocal num_traces
num_traces += 1
return x + x
def f_cond(x):
return lax.cond(True, g, g, x)
@jax.jit
def f_jit(x):
nonlocal num_traces
num_traces += 1
return x + x
for f in [f_jit, f_cond]:
allow_promotion = config.jax_numpy_rank_promotion
try:
num_traces = 0
@jax.jit
def f(x):
nonlocal num_traces
num_traces += 1
return x + x
x = jnp.zeros((2, 2))
f(x)
self.assertEqual(num_traces, 1)
f(x)
self.assertEqual(num_traces, 1)
with jax.numpy_rank_promotion("warn"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_numpy_rank_promotion = "raise"
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
f(x)
self.assertEqual(num_traces, nt + 1)
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_numpy_rank_promotion = allow_promotion
@jtu.with_config(jax_numpy_rank_promotion="raise")
class RematTest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -4226,6 +4232,7 @@ class RematTest(jtu.JaxTestCase):
_ = api.linearize(partial(f, core.unit), 3.)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class JaxprTest(jtu.JaxTestCase):
def test_scalar_literals(self):
@ -4369,6 +4376,7 @@ class JaxprTest(jtu.JaxTestCase):
self.assertLen(jaxpr.eqns, 0)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomJVPTest(jtu.JaxTestCase):
def test_basic(self):
@ -5311,7 +5319,7 @@ class CustomJVPTest(jtu.JaxTestCase):
for i in range(2):
buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
buf = jnp.matmul(bd_rowsum, buf)
return buf * val
return buf * val[None, :]
# -----
# Vertorizing will raise shape error
bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
@ -5343,6 +5351,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertEqual(shape, ())
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVJPTest(jtu.JaxTestCase):
def test_basic(self):
@ -6311,6 +6320,7 @@ def transpose_unary(f, x_example):
return transposed
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomTransposeTest(jtu.JaxTestCase):
def test_linear_call(self):
@ -6639,6 +6649,7 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVmapTest(jtu.JaxTestCase):
def test_basic(self):
@ -7023,6 +7034,7 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
@ -7060,6 +7072,7 @@ class CustomApiTest(jtu.JaxTestCase):
self.assertIsInstance(getattr(f, method), Callable)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class InvertibleADTest(jtu.JaxTestCase):
@jtu.ignore_warning(message="Values that an @invertible function closes")
@ -7168,6 +7181,7 @@ class InvertibleADTest(jtu.JaxTestCase):
check_dtypes=True)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BufferDonationTest(jtu.BufferDonationTestCase):
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
@ -7190,6 +7204,7 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
pmap_fun(a) # doesn't crash
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NamedCallTest(jtu.JaxTestCase):
def test_default_name(self):
@ -7270,6 +7285,7 @@ class NamedCallTest(jtu.JaxTestCase):
self.assertRaises(OverflowError, f, int_min - 1)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BackendsTest(jtu.JaxTestCase):
@unittest.skipIf(not sys.executable, "test requires sys.executable")
@ -7292,6 +7308,7 @@ class BackendsTest(jtu.JaxTestCase):
assert "No GPU/TPU found" not in result.stderr.decode()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CleanupTest(jtu.JaxTestCase):
def test_call_wrapped_second_phase_cleanup(self):
try: