It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.
Happily, this also allows us to remove a number of TPU special cases for the same reason.
PiperOrigin-RevId: 539101155
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Improve JAX test PRNG APIs to fix correlations between test cases.
In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.
This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.
* Fix some failing tests.
* Fix more test failures.
Simplify ediff1d implementation and make it more permissive when casting.
* Relax test tolerance of laplace CDF test.
* Fix some missing cases of broadcasting in np.einsum.
In particular, np.einsum allows one side of a batch or contracting dimension to have size 1 even if the other side has a non-1 size.
Implement np.matmul in terms of np.einsum. This allows us to reuse einsum's logic for performing broadcasting without explicitly broadcasting the LHS and RHS together.
* Add regression test.
Fixes#2189.
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.
Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.
Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').
Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).
We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.
* Fix test tolerances.
* Fix dtype canonicalization for test tolerances.
* Relax core test_vjp tolerance.
The error was that `lhs_names` and `rhs_names` included `batch_names` as
prefixes, but the reshaping logic was written as if they did not include
batch_names (and so batch_names had to be prepended).