* Don't perform size 0 slices into scipy rotations.
This is disallowed by scipy after https://github.com/scipy/scipy/pull/21776.
(cherry picked from commit 834e71bbe1e4e23de8f658e7380c385be7c5099a)
* test: change random matrix generation for Rotation
(cherry picked from commit 4d433d063e2274d2998913d615aa168889a91b9a)
* Skip lax_scipy_special_functions_test as it will be fixed in the next JAX version
---------
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This matches scipy behavior as of 1.11.
I also went through the tests and enabled a bunch of disabled tests which appear to pass now(?).
PiperOrigin-RevId: 655719643
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.
Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014
The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it's better to loosen the tolerance but preserve the original
execution mode.
The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0