mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #9330 from jakevdp:rank-promotion-final
PiperOrigin-RevId: 427878821
This commit is contained in:
commit
8b4a7ce910
@ -29,6 +29,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
`dialect=` is passed.
|
||||
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
|
||||
`ir.Module` object instead of its string representation.
|
||||
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
|
||||
default. To recover the previous behavior, use the `jax.test_util.with_config`
|
||||
decorator:
|
||||
```python
|
||||
@jtu.with_config(jax_numpy_rank_promotion='allow')
|
||||
class MyTest(jtu.JaxTestCase):
|
||||
...
|
||||
```
|
||||
|
||||
## jaxlib 0.1.76 (Jan 27, 2022)
|
||||
|
||||
|
@ -915,7 +915,10 @@ def with_config(**kwds):
|
||||
|
||||
class JaxTestCase(parameterized.TestCase):
|
||||
"""Base class for JAX tests including numerical checks and boilerplate."""
|
||||
_default_config = {'jax_enable_checks': True}
|
||||
_default_config = {
|
||||
'jax_enable_checks': True,
|
||||
'jax_numpy_rank_promotion': 'raise',
|
||||
}
|
||||
|
||||
# TODO(mattjj): this obscures the error messages from failures, figure out how
|
||||
# to re-enable it
|
||||
|
@ -150,6 +150,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
|
||||
return f1(*args1)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow")
|
||||
class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -70,7 +70,6 @@ python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
"""Shared tests between the Python and the C++ jax,jit implementations.
|
||||
|
||||
@ -860,7 +859,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
python_should_be_executing = False
|
||||
self.assertEqual(x, f(x))
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
@property
|
||||
@ -868,7 +867,6 @@ class PythonJitTest(CPPJitTest):
|
||||
return api._python_jit
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_item(self):
|
||||
@ -3416,7 +3414,6 @@ class APITest(jtu.JaxTestCase):
|
||||
FLAGS.jax_numpy_rank_promotion = allow_promotion
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -4273,7 +4270,6 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_ = api.linearize(partial(f, core.unit), 3.)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scalar_literals(self):
|
||||
@ -4417,7 +4413,6 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
self.assertLen(jaxpr.eqns, 0)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -5392,7 +5387,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertEqual(shape, ())
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -6361,7 +6355,6 @@ def transpose_unary(f, x_example):
|
||||
return transposed
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_linear_call(self):
|
||||
@ -6690,7 +6683,6 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomVmapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@ -7117,7 +7109,6 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(str(jaxpr), str(jaxpr_ref))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CustomApiTest(jtu.JaxTestCase):
|
||||
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
|
||||
|
||||
@ -7155,7 +7146,6 @@ class CustomApiTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(getattr(f, method), Callable)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.ignore_warning(message="Values that an @invertible function closes")
|
||||
@ -7264,7 +7254,6 @@ class InvertibleADTest(jtu.JaxTestCase):
|
||||
check_dtypes=True)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
|
||||
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
|
||||
@ -7287,7 +7276,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
pmap_fun(a) # doesn't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NamedCallTest(jtu.JaxTestCase):
|
||||
|
||||
def test_default_name(self):
|
||||
@ -7368,7 +7356,6 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BackendsTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not sys.executable, "test requires sys.executable")
|
||||
@ -7391,7 +7378,6 @@ class BackendsTest(jtu.JaxTestCase):
|
||||
assert "No GPU/TPU found" not in result.stderr.decode()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class CleanupTest(jtu.JaxTestCase):
|
||||
def test_call_wrapped_second_phase_cleanup(self):
|
||||
try:
|
||||
@ -7552,6 +7538,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0])
|
||||
self.assertEqual(4, jaxpr.out_avals[0].shape[1])
|
||||
|
||||
@jax.numpy_rank_promotion("allow") # explicitly exercises implicit rank promotion.
|
||||
def test_basic_batchpoly_neuralnet(self):
|
||||
def predict(params, inputs):
|
||||
for W, b in params:
|
||||
|
@ -40,7 +40,6 @@ config.parse_flags_with_absl()
|
||||
# These are 'manual' tests for batching (vmap). The more exhaustive, more
|
||||
# systematic tests are in lax_test.py's LaxVmapTest class.
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
def testConstantFunction(self):
|
||||
|
@ -93,7 +93,6 @@ def _zero_for_irfft(z, axes):
|
||||
return jnp.concatenate(parts, axis=axis)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class FftTest(jtu.JaxTestCase):
|
||||
|
||||
def testNotImplemented(self):
|
||||
|
@ -26,7 +26,6 @@ import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.
|
||||
|
@ -189,7 +189,7 @@ def check_grads_bilinear(f, args, order,
|
||||
check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
|
||||
modes=modes, atol=atol, rtol=rtol, eps=1.)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
|
||||
class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
|
@ -30,7 +30,6 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class EinsumTest(jtu.JaxTestCase):
|
||||
|
||||
def _check(self, s, *ops):
|
||||
|
@ -414,7 +414,6 @@ MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
|
||||
MODES = ["clip", "drop", "promise_in_bounds"]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class IndexingTest(jtu.JaxTestCase):
|
||||
"""Tests for Numpy indexing translation rules."""
|
||||
|
||||
@ -997,7 +996,6 @@ def _update_tol(op):
|
||||
return tol
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
|
@ -519,7 +519,6 @@ def _promote_like_jnp(fun, inexact=False):
|
||||
return wrapper
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed Numpy implementation."""
|
||||
|
||||
@ -5915,7 +5914,6 @@ GRAD_SPECIAL_VALUE_TEST_RECORDS = [
|
||||
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
|
||||
]
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NumpyGradTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
@ -6020,7 +6018,6 @@ class NumpyGradTests(jtu.JaxTestCase):
|
||||
tol = 3e-2
|
||||
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
|
||||
def testWrappedSignaturesMatch(self):
|
||||
@ -6136,7 +6133,6 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
|
||||
yield arg_dtypes
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -6168,7 +6164,6 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
# that jnp returns float32. e.g. np.cos(np.uint8(0))
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NumpyDocTests(jtu.JaxTestCase):
|
||||
|
||||
def test_lax_numpy_docstrings(self):
|
||||
|
@ -25,7 +25,6 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class VectorizeTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -64,7 +64,6 @@ def rand_sym_pos_def(rng, shape, dtype):
|
||||
return matrix @ matrix.T.conj()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
def _fetch_preconditioner(self, preconditioner, A, rng=None):
|
||||
|
@ -144,7 +144,6 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed Scipy implementation."""
|
||||
|
||||
|
@ -181,7 +181,6 @@ LAX_OPS = [
|
||||
]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxTest(jtu.JaxTestCase):
|
||||
"""Numerical tests for LAX operations."""
|
||||
|
||||
@ -2669,7 +2668,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
# check casting to ndarray works
|
||||
@ -2872,7 +2870,6 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
np.log1p(np.float32(1e-5)), lax.log1p(np.complex64(1e-5)))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxNamedShapeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_abstract_eval(self):
|
||||
|
@ -42,7 +42,6 @@ float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion='raise')
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
def testNotImplemented(self):
|
||||
@ -957,7 +956,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertFalse(np.any(np.isnan(cube_func(a))))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion='raise')
|
||||
class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1374,7 +1372,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion='raise')
|
||||
|
||||
class LaxLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
def run_test(self, alpha, beta):
|
||||
|
@ -35,7 +35,6 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NNFunctionsTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testSoftplusGrad(self):
|
||||
@ -230,7 +229,7 @@ INITIALIZER_RECS = [
|
||||
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
|
||||
]
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
|
||||
class NNInitializersTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
|
@ -74,7 +74,6 @@ def check_1d_2d_mesh(f, set_mesh):
|
||||
|
||||
|
||||
# TODO(skye): make the buffer donation utils part of JaxTestCase
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
@jtu.with_mesh([('x', 1)])
|
||||
@ -635,7 +634,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(f(1, 'bye'), 5)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
@ -953,7 +951,6 @@ def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
@check_1d_2d_mesh(set_mesh=True)
|
||||
def testNonDivisibleArgs(self, mesh, resources):
|
||||
@ -1181,7 +1178,6 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
f(x)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
def testOpShardingRoundTrip(self):
|
||||
|
@ -109,7 +109,6 @@ ignore_xmap_warning = partial(
|
||||
jtu.ignore_warning, message=".*is an experimental.*")
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
@property
|
||||
@ -1905,7 +1904,6 @@ class CppPmapTest(PythonPmapTest):
|
||||
return src_api._cpp_pmap
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class VmapOfPmapTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(apaszke)
|
||||
@ -1948,7 +1946,6 @@ class VmapOfPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -2134,7 +2131,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
|
||||
def testAllDevices(self):
|
||||
@ -2387,7 +2383,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
jax.grad(mk_case(vmap))(x, y))
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def testThreadsafeIndexing(self):
|
||||
@ -2493,7 +2488,6 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
_ = x[0]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
|
||||
def testShardsPerAxis(self):
|
||||
@ -2623,7 +2617,6 @@ def _spec_str(spec):
|
||||
f"{spec.mesh_mapping},)")
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class ShardArgsTest(jtu.JaxTestCase):
|
||||
|
||||
def numpy_array(x):
|
||||
|
@ -58,7 +58,6 @@ def _compute_relative_diff(actual, expected):
|
||||
_dot = functools.partial(jnp.dot, precision="highest")
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class QdwhTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -56,7 +56,6 @@ PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
|
||||
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
def testThreefry2x32(self):
|
||||
@ -315,7 +314,6 @@ class PrngTest(jtu.JaxTestCase):
|
||||
lambda: keys[0, 1, None, 2])
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def _CheckCollisions(self, samples, nbits):
|
||||
@ -1226,7 +1224,6 @@ double_threefry_prng_impl = prng.PRNGImpl(
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return prng.seed_with_impl(double_threefry_prng_impl, seed)
|
||||
@ -1255,7 +1252,6 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return random.rbg_key(seed)
|
||||
|
@ -42,7 +42,7 @@ def _get_dctn_test_s(shape, axes):
|
||||
s_list.extend(itertools.product(*[[shape[ax]+i for i in range(-shape[ax]+1, shape[ax]+1)] for ax in axes]))
|
||||
return s_list
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
|
||||
class LaxBackedScipyFftTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.fft implementations"""
|
||||
|
||||
|
@ -57,7 +57,6 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
|
||||
return result
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NdimageTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -64,7 +64,6 @@ def zakharovFromIndices(x, ii):
|
||||
return answer
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class TestBFGS(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -141,7 +140,6 @@ class TestBFGS(jtu.JaxTestCase):
|
||||
jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class TestLBFGS(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -35,7 +35,6 @@ threedim_shapes = [(2, 2, 2), (3, 3, 2), (4, 4, 2), (5, 5, 2)]
|
||||
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.stats implementations"""
|
||||
|
||||
|
@ -39,7 +39,6 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class ShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -277,7 +276,6 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
# TODO(skye): add more error tests
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class ShardedJitErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -300,7 +298,6 @@ class ShardedJitErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
# Tests that don't need a TPU to run.
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class ShardedJitTestNoTpu(jtu.JaxTestCase):
|
||||
|
||||
def testTranslationRule(self):
|
||||
@ -329,7 +326,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
|
||||
# Annotation from sharded_jit
|
||||
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
|
||||
class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -116,7 +116,6 @@ def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
|
||||
return _rand_sparse
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class cuSparseTest(jtu.JaxTestCase):
|
||||
def gpu_dense_conversion_warning_context(self, dtype):
|
||||
if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
|
||||
@ -555,7 +554,6 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class BCOOTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
@ -1679,7 +1677,6 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class SparseGradTest(jtu.JaxTestCase):
|
||||
def test_sparse_grad(self):
|
||||
rng_sparse = rand_sparse(self.rng())
|
||||
@ -1702,7 +1699,6 @@ class SparseGradTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class SparseObjectTest(jtu.JaxTestCase):
|
||||
def test_repr(self):
|
||||
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
|
||||
@ -1898,7 +1894,6 @@ class SparseObjectTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(M.sum(), Msp.sum())
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class SparseRandomTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(
|
||||
|
@ -30,7 +30,6 @@ from jax.experimental.sparse.transform import (
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class SparsifyTest(jtu.JaxTestCase):
|
||||
@classmethod
|
||||
def sparsify(cls, f):
|
||||
|
@ -210,7 +210,6 @@ def schedules(sizes: Dict[str, int]
|
||||
yield axis_resources, mesh_data
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class XMapTestCase(jtu.BufferDonationTestCase):
|
||||
pass
|
||||
|
||||
@ -1178,7 +1177,6 @@ class PDotTests(XMapTestCase):
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
@ -1410,7 +1408,6 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
xmap(lambda x: x, (p,), (p, ['x']))([x, x, x]) # Error, we raise a generic tree mismatch message
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class NamedAutodiffTests(jtu.JaxTestCase):
|
||||
|
||||
def testVjpReduceAxes(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user