mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
disable implicit rank promotion for api_test
This commit is contained in:
parent
16c809ce7f
commit
3197aacbfc
@ -69,6 +69,7 @@ python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
"""Shared tests between the Python and the C++ jax,jit implementations.
|
||||
|
||||
@ -829,6 +830,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f({E.A: 1.0, E.B: 2.0})
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
@property
|
||||
@ -836,6 +838,7 @@ class PythonJitTest(CPPJitTest):
|
||||
return api._python_jit
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_item(self):
|
||||
@ -1759,7 +1762,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
A = MyArgArray((3, 4), jnp.float32)
|
||||
b = MyArgArray((5,), jnp.float32)
|
||||
b = MyArgArray((1, 5), jnp.float32)
|
||||
x = MyArgArray((4, 5), jnp.float32)
|
||||
out_shape = api.eval_shape(fun, A, b, x)
|
||||
|
||||
@ -3193,50 +3196,6 @@ class APITest(jtu.JaxTestCase):
|
||||
finally:
|
||||
FLAGS.jax_default_matmul_precision = precision
|
||||
|
||||
def test_rank_promotion_forces_retrace(self):
|
||||
num_traces = 0
|
||||
|
||||
def g(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
def f_cond(x):
|
||||
return lax.cond(True, g, g, x)
|
||||
|
||||
@jax.jit
|
||||
def f_jit(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
|
||||
for f in [f_jit, f_cond]:
|
||||
allow_promotion = config.jax_numpy_rank_promotion
|
||||
try:
|
||||
num_traces = 0
|
||||
@jax.jit
|
||||
def f(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
x = jnp.zeros((2, 2))
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 1)
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 1)
|
||||
with jax.numpy_rank_promotion("warn"):
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 2)
|
||||
FLAGS.jax_numpy_rank_promotion = "raise"
|
||||
f(x)
|
||||
self.assertGreaterEqual(num_traces, 2)
|
||||
nt = num_traces
|
||||
f(x)
|
||||
self.assertEqual(num_traces, nt + 1)
|
||||
f(x)
|
||||
self.assertEqual(num_traces, nt + 1)
|
||||
finally:
|
||||
FLAGS.jax_numpy_rank_promotion = allow_promotion
|
||||
|
||||
def test_backward_pass_ref_dropping(self):
|
||||
refs = []
|
||||
|
||||
@ -3333,7 +3292,7 @@ class APITest(jtu.JaxTestCase):
|
||||
def sigmoid(x):
|
||||
return 1. / (1. + jnp.exp(-x))
|
||||
|
||||
x = jnp.ones((50,))
|
||||
x = jnp.ones((1, 50))
|
||||
A = jnp.array(npr.randn(50, 50))
|
||||
|
||||
@jax.jit
|
||||
@ -3370,6 +3329,53 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
|
||||
class RankPromotionTest(jtu.JaxTestCase):
|
||||
def test_rank_promotion_forces_retrace(self):
|
||||
num_traces = 0
|
||||
|
||||
def g(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
def f_cond(x):
|
||||
return lax.cond(True, g, g, x)
|
||||
|
||||
@jax.jit
|
||||
def f_jit(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
|
||||
for f in [f_jit, f_cond]:
|
||||
allow_promotion = config.jax_numpy_rank_promotion
|
||||
try:
|
||||
num_traces = 0
|
||||
@jax.jit
|
||||
def f(x):
|
||||
nonlocal num_traces
|
||||
num_traces += 1
|
||||
return x + x
|
||||
x = jnp.zeros((2, 2))
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 1)
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 1)
|
||||
with jax.numpy_rank_promotion("warn"):
|
||||
f(x)
|
||||
self.assertEqual(num_traces, 2)
|
||||
FLAGS.jax_numpy_rank_promotion = "raise"
|
||||
f(x)
|
||||
self.assertGreaterEqual(num_traces, 2)
|
||||
nt = num_traces
|
||||
f(x)
|
||||
self.assertEqual(num_traces, nt + 1)
|
||||
f(x)
|
||||
self.assertEqual(num_traces, nt + 1)
|
||||
finally:
|
||||
FLAGS.jax_numpy_rank_promotion = allow_promotion
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -4226,6 +4232,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_ = api.linearize(partial(f, core.unit), 3.)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scalar_literals(self):
|
||||
@ -4369,6 +4376,7 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
self.assertLen(jaxpr.eqns, 0)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -5311,7 +5319,7 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
for i in range(2):
|
||||
buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
|
||||
buf = jnp.matmul(bd_rowsum, buf)
|
||||
return buf * val
|
||||
return buf * val[None, :]
|
||||
# -----
|
||||
# Vertorizing will raise shape error
|
||||
bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
|
||||
@ -5343,6 +5351,7 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertEqual(shape, ())
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -6311,6 +6320,7 @@ def transpose_unary(f, x_example):
|
||||
return transposed
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_linear_call(self):
|
||||
@ -6639,6 +6649,7 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomVmapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -7023,6 +7034,7 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomApiTest(jtu.JaxTestCase):
|
||||
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
|
||||
|
||||
@ -7060,6 +7072,7 @@ class CustomApiTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(getattr(f, method), Callable)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.ignore_warning(message="Values that an @invertible function closes")
|
||||
@ -7168,6 +7181,7 @@ class InvertibleADTest(jtu.JaxTestCase):
|
||||
check_dtypes=True)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
|
||||
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
|
||||
@ -7190,6 +7204,7 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
pmap_fun(a) # doesn't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NamedCallTest(jtu.JaxTestCase):
|
||||
|
||||
def test_default_name(self):
|
||||
@ -7270,6 +7285,7 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BackendsTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not sys.executable, "test requires sys.executable")
|
||||
@ -7292,6 +7308,7 @@ class BackendsTest(jtu.JaxTestCase):
|
||||
assert "No GPU/TPU found" not in result.stderr.decode()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CleanupTest(jtu.JaxTestCase):
|
||||
def test_call_wrapped_second_phase_cleanup(self):
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user