Merge pull request #9330 from jakevdp:rank-promotion-final

PiperOrigin-RevId: 427878821
This commit is contained in:
jax authors 2022-02-10 17:08:39 -08:00
commit 8b4a7ce910
29 changed files with 20 additions and 73 deletions

View File

@ -29,6 +29,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTest(jtu.JaxTestCase):
...
```
## jaxlib 0.1.76 (Jan 27, 2022)

View File

@ -915,7 +915,10 @@ def with_config(**kwds):
class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_config = {'jax_enable_checks': True}
_default_config = {
'jax_enable_checks': True,
'jax_numpy_rank_promotion': 'raise',
}
# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it

View File

@ -150,6 +150,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
return f1(*args1)
@jtu.with_config(jax_numpy_rank_promotion="allow")
class JaxToTfTestCase(jtu.JaxTestCase):
def setUp(self):

View File

@ -70,7 +70,6 @@ python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CPPJitTest(jtu.BufferDonationTestCase):
"""Shared tests between the Python and the C++ jax,jit implementations.
@ -860,7 +859,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
python_should_be_executing = False
self.assertEqual(x, f(x))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PythonJitTest(CPPJitTest):
@property
@ -868,7 +867,6 @@ class PythonJitTest(CPPJitTest):
return api._python_jit
@jtu.with_config(jax_numpy_rank_promotion="raise")
class APITest(jtu.JaxTestCase):
def test_grad_item(self):
@ -3416,7 +3414,6 @@ class APITest(jtu.JaxTestCase):
FLAGS.jax_numpy_rank_promotion = allow_promotion
@jtu.with_config(jax_numpy_rank_promotion="raise")
class RematTest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -4273,7 +4270,6 @@ class RematTest(jtu.JaxTestCase):
_ = api.linearize(partial(f, core.unit), 3.)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class JaxprTest(jtu.JaxTestCase):
def test_scalar_literals(self):
@ -4417,7 +4413,6 @@ class JaxprTest(jtu.JaxTestCase):
self.assertLen(jaxpr.eqns, 0)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomJVPTest(jtu.JaxTestCase):
def test_basic(self):
@ -5392,7 +5387,6 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertEqual(shape, ())
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVJPTest(jtu.JaxTestCase):
def test_basic(self):
@ -6361,7 +6355,6 @@ def transpose_unary(f, x_example):
return transposed
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomTransposeTest(jtu.JaxTestCase):
def test_linear_call(self):
@ -6690,7 +6683,6 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVmapTest(jtu.JaxTestCase):
def test_basic(self):
@ -7117,7 +7109,6 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertEqual(str(jaxpr), str(jaxpr_ref))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
@ -7155,7 +7146,6 @@ class CustomApiTest(jtu.JaxTestCase):
self.assertIsInstance(getattr(f, method), Callable)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class InvertibleADTest(jtu.JaxTestCase):
@jtu.ignore_warning(message="Values that an @invertible function closes")
@ -7264,7 +7254,6 @@ class InvertibleADTest(jtu.JaxTestCase):
check_dtypes=True)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BufferDonationTest(jtu.BufferDonationTestCase):
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
@ -7287,7 +7276,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
pmap_fun(a) # doesn't crash
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NamedCallTest(jtu.JaxTestCase):
def test_default_name(self):
@ -7368,7 +7356,6 @@ class NamedCallTest(jtu.JaxTestCase):
self.assertRaises(OverflowError, f, int_min - 1)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BackendsTest(jtu.JaxTestCase):
@unittest.skipIf(not sys.executable, "test requires sys.executable")
@ -7391,7 +7378,6 @@ class BackendsTest(jtu.JaxTestCase):
assert "No GPU/TPU found" not in result.stderr.decode()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class CleanupTest(jtu.JaxTestCase):
def test_call_wrapped_second_phase_cleanup(self):
try:
@ -7552,6 +7538,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0])
self.assertEqual(4, jaxpr.out_avals[0].shape[1])
@jax.numpy_rank_promotion("allow") # explicitly exercises implicit rank promotion.
def test_basic_batchpoly_neuralnet(self):
def predict(params, inputs):
for W, b in params:

View File

@ -40,7 +40,6 @@ config.parse_flags_with_absl()
# These are 'manual' tests for batching (vmap). The more exhaustive, more
# systematic tests are in lax_test.py's LaxVmapTest class.
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BatchingTest(jtu.JaxTestCase):
def testConstantFunction(self):

View File

@ -93,7 +93,6 @@ def _zero_for_irfft(z, axes):
return jnp.concatenate(parts, axis=axis)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class FftTest(jtu.JaxTestCase):
def testNotImplemented(self):

View File

@ -26,7 +26,6 @@ import numpy as np
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class InfeedTest(jtu.JaxTestCase):
@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.

View File

@ -189,7 +189,7 @@ def check_grads_bilinear(f, args, order,
check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
modes=modes, atol=atol, rtol=rtol, eps=1.)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.named_parameters(itertools.chain.from_iterable(

View File

@ -30,7 +30,6 @@ from jax.config import config
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class EinsumTest(jtu.JaxTestCase):
def _check(self, s, *ops):

View File

@ -414,7 +414,6 @@ MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
MODES = ["clip", "drop", "promise_in_bounds"]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""
@ -997,7 +996,6 @@ def _update_tol(op):
return tol
@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexedUpdateTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({

View File

@ -519,7 +519,6 @@ def _promote_like_jnp(fun, inexact=False):
return wrapper
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
@ -5915,7 +5914,6 @@ GRAD_SPECIAL_VALUE_TEST_RECORDS = [
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyGradTests(jtu.JaxTestCase):
@parameterized.named_parameters(itertools.chain.from_iterable(
@ -6020,7 +6018,6 @@ class NumpyGradTests(jtu.JaxTestCase):
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpySignaturesTest(jtu.JaxTestCase):
def testWrappedSignaturesMatch(self):
@ -6136,7 +6133,6 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
yield arg_dtypes
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyUfuncTests(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -6168,7 +6164,6 @@ class NumpyUfuncTests(jtu.JaxTestCase):
# that jnp returns float32. e.g. np.cos(np.uint8(0))
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyDocTests(jtu.JaxTestCase):
def test_lax_numpy_docstrings(self):

View File

@ -25,7 +25,6 @@ from jax.config import config
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class VectorizeTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -64,7 +64,6 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):
def _fetch_preconditioner(self, preconditioner, A, rng=None):

View File

@ -144,7 +144,6 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Scipy implementation."""

View File

@ -181,7 +181,6 @@ LAX_OPS = [
]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxTest(jtu.JaxTestCase):
"""Numerical tests for LAX operations."""
@ -2669,7 +2668,6 @@ class LaxTest(jtu.JaxTestCase):
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works
@ -2872,7 +2870,6 @@ class LazyConstantTest(jtu.JaxTestCase):
np.log1p(np.float32(1e-5)), lax.log1p(np.complex64(1e-5)))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxNamedShapeTest(jtu.JaxTestCase):
def test_abstract_eval(self):

View File

@ -42,7 +42,6 @@ float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
@jtu.with_config(jax_numpy_rank_promotion='raise')
class NumpyLinalgTest(jtu.JaxTestCase):
def testNotImplemented(self):
@ -957,7 +956,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertFalse(np.any(np.isnan(cube_func(a))))
@jtu.with_config(jax_numpy_rank_promotion='raise')
class ScipyLinalgTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
@ -1374,7 +1372,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)
@jtu.with_config(jax_numpy_rank_promotion='raise')
class LaxLinalgTest(jtu.JaxTestCase):
def run_test(self, alpha, beta):

View File

@ -35,7 +35,6 @@ from jax.config import config
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NNFunctionsTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
@ -230,7 +229,7 @@ INITIALIZER_RECS = [
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NNInitializersTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":

View File

@ -74,7 +74,6 @@ def check_1d_2d_mesh(f, set_mesh):
# TODO(skye): make the buffer donation utils part of JaxTestCase
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PJitTest(jtu.BufferDonationTestCase):
@jtu.with_mesh([('x', 1)])
@ -635,7 +634,6 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertEqual(f(1, 'bye'), 5)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
@ -953,7 +951,6 @@ def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PJitErrorTest(jtu.JaxTestCase):
@check_1d_2d_mesh(set_mesh=True)
def testNonDivisibleArgs(self, mesh, resources):
@ -1181,7 +1178,6 @@ class PJitErrorTest(jtu.JaxTestCase):
f(x)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class UtilTest(jtu.JaxTestCase):
def testOpShardingRoundTrip(self):

View File

@ -109,7 +109,6 @@ ignore_xmap_warning = partial(
jtu.ignore_warning, message=".*is an experimental.*")
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PythonPmapTest(jtu.JaxTestCase):
@property
@ -1905,7 +1904,6 @@ class CppPmapTest(PythonPmapTest):
return src_api._cpp_pmap
@jtu.with_config(jax_numpy_rank_promotion="raise")
class VmapOfPmapTest(jtu.JaxTestCase):
# TODO(apaszke)
@ -1948,7 +1946,6 @@ class VmapOfPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class VmapPmapCollectivesTest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -2134,7 +2131,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PmapWithDevicesTest(jtu.JaxTestCase):
def testAllDevices(self):
@ -2387,7 +2383,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
jax.grad(mk_case(vmap))(x, y))
@jtu.with_config(jax_numpy_rank_promotion="raise")
class ShardedDeviceArrayTest(jtu.JaxTestCase):
def testThreadsafeIndexing(self):
@ -2493,7 +2488,6 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
_ = x[0]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class SpecToIndicesTest(jtu.JaxTestCase):
def testShardsPerAxis(self):
@ -2623,7 +2617,6 @@ def _spec_str(spec):
f"{spec.mesh_mapping},)")
@jtu.with_config(jax_numpy_rank_promotion="raise")
class ShardArgsTest(jtu.JaxTestCase):
def numpy_array(x):

View File

@ -58,7 +58,6 @@ def _compute_relative_diff(actual, expected):
_dot = functools.partial(jnp.dot, precision="highest")
@jtu.with_config(jax_numpy_rank_promotion="raise")
class QdwhTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -56,7 +56,6 @@ PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PrngTest(jtu.JaxTestCase):
def testThreefry2x32(self):
@ -315,7 +314,6 @@ class PrngTest(jtu.JaxTestCase):
lambda: keys[0, 1, None, 2])
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomTest(jtu.JaxTestCase):
def _CheckCollisions(self, samples, nbits):
@ -1226,7 +1224,6 @@ double_threefry_prng_impl = prng.PRNGImpl(
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return prng.seed_with_impl(double_threefry_prng_impl, seed)
@ -1255,7 +1252,6 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return random.rbg_key(seed)

View File

@ -42,7 +42,7 @@ def _get_dctn_test_s(shape, axes):
s_list.extend(itertools.product(*[[shape[ax]+i for i in range(-shape[ax]+1, shape[ax]+1)] for ax in axes]))
return s_list
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyFftTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.fft implementations"""

View File

@ -57,7 +57,6 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
return result
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NdimageTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -64,7 +64,6 @@ def zakharovFromIndices(x, ii):
return answer
@jtu.with_config(jax_numpy_rank_promotion="raise")
class TestBFGS(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
@ -141,7 +140,6 @@ class TestBFGS(jtu.JaxTestCase):
jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
@jtu.with_config(jax_numpy_rank_promotion="raise")
class TestLBFGS(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -35,7 +35,6 @@ threedim_shapes = [(2, 2, 2), (3, 3, 2), (4, 4, 2), (5, 5, 2)]
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipySignalTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""

View File

@ -39,7 +39,6 @@ from jax.config import config
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class ShardedJitTest(jtu.JaxTestCase):
def setUp(self):
@ -277,7 +276,6 @@ class ShardedJitTest(jtu.JaxTestCase):
# TODO(skye): add more error tests
@jtu.with_config(jax_numpy_rank_promotion="raise")
class ShardedJitErrorsTest(jtu.JaxTestCase):
def setUp(self):
@ -300,7 +298,6 @@ class ShardedJitErrorsTest(jtu.JaxTestCase):
# Tests that don't need a TPU to run.
@jtu.with_config(jax_numpy_rank_promotion="raise")
class ShardedJitTestNoTpu(jtu.JaxTestCase):
def testTranslationRule(self):
@ -329,7 +326,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
# Annotation from sharded_jit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PmapOfShardedJitTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -116,7 +116,6 @@ def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
return _rand_sparse
@jtu.with_config(jax_numpy_rank_promotion="raise")
class cuSparseTest(jtu.JaxTestCase):
def gpu_dense_conversion_warning_context(self, dtype):
if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
@ -555,7 +554,6 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class BCOOTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
@ -1679,7 +1677,6 @@ class BCOOTest(jtu.JaxTestCase):
self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class SparseGradTest(jtu.JaxTestCase):
def test_sparse_grad(self):
rng_sparse = rand_sparse(self.rng())
@ -1702,7 +1699,6 @@ class SparseGradTest(jtu.JaxTestCase):
self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class SparseObjectTest(jtu.JaxTestCase):
def test_repr(self):
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
@ -1898,7 +1894,6 @@ class SparseObjectTest(jtu.JaxTestCase):
self.assertArraysEqual(M.sum(), Msp.sum())
@jtu.with_config(jax_numpy_rank_promotion="raise")
class SparseRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(

View File

@ -30,7 +30,6 @@ from jax.experimental.sparse.transform import (
config.parse_flags_with_absl()
@jtu.with_config(jax_numpy_rank_promotion="raise")
class SparsifyTest(jtu.JaxTestCase):
@classmethod
def sparsify(cls, f):

View File

@ -210,7 +210,6 @@ def schedules(sizes: Dict[str, int]
yield axis_resources, mesh_data
@jtu.with_config(jax_numpy_rank_promotion="raise")
class XMapTestCase(jtu.BufferDonationTestCase):
pass
@ -1178,7 +1177,6 @@ class PDotTests(XMapTestCase):
self.assertAllClose(out, expected, check_dtypes=True)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class XMapErrorTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 2)])
@ -1410,7 +1408,6 @@ class XMapErrorTest(jtu.JaxTestCase):
xmap(lambda x: x, (p,), (p, ['x']))([x, x, x]) # Error, we raise a generic tree mismatch message
@jtu.with_config(jax_numpy_rank_promotion="raise")
class NamedAutodiffTests(jtu.JaxTestCase):
def testVjpReduceAxes(self):