This matches scipy behavior as of 1.11.
I also went through the tests and enabled a bunch of disabled tests which appear to pass now(?).
PiperOrigin-RevId: 655719643
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.
Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014
The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it's better to loosen the tolerance but preserve the original
execution mode.
The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0