This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes#22398 directly.
Reverts e416c6675acfd82866a6e83e8c221640c4d02f29
PiperOrigin-RevId: 660438802
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
We have some differences between Triton codegen and other fusion codegen,
namely for Remainder/Fmod and Cbrt. Unify that.
- Remove two unused math functions.
- Add mapping from kRemainder to kFmod.
- Use kCbrt device function in elemental_ir_emitter.
PiperOrigin-RevId: 567274915
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.
Running 256 entry vectors in spectral_dac is too slow, so let's replace that
with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires
us to widen the tolerance a bit due to vectorization changes.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 514440062
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.
This allows them to run in ASAN in more situations.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 511829646
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816