The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
At `float16` precision, one LAX reduction test was found to be flaky, and disabled in https://github.com/jax-ml/jax/pull/25443. This change re-enables that test with a slightly relaxed tolerance instead.
PiperOrigin-RevId: 706771186
Numpy recently merged support for the 2023.12 revision of the Array API:
https://github.com/numpy/numpy/pull/26724
This breaks two of our tests:
1. The first breakage was caused by differences in how numpy and JAX
cast negative floats to `uint8`. Specifically
`np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
`jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
We don't make any promises about consistency with casting floats to
ints, noting that this can even be backend dependent. To fix our
test, we now only generate positive inputs when the output dtype is
unsigned.
2. The second failure was caused by the fact that the approach we took
in #20550 to support backwards compatibility and the Array API for
`clip` differs from the one used in numpy/numpy#26724. Again, the
behavior is consistent, but it produces a different signature. I've
skipped checking `clip`'s signature, but we should revisit it once
the `a_min` and `a_max` parameters have been removed from JAX.
Fixes#22251
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 572587137
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.
Fixes#15068
Co-authored by: Matthew Johnson <mattjj@google.com>
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.
Fixes https://github.com/google/jax/issues/12008
PiperOrigin-RevId: 488749199
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670