44 Commits

Author SHA1 Message Date
Jake VanderPlas
b5e7b60d6a jax.numpy reductions: avoid upcast of f16 when dtype is specified by user 2025-02-12 11:49:39 -08:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Dan Foreman-Mackey
ed4e9823b1 Relax tolerance for LAX reduction test in float16.
At `float16` precision, one LAX reduction test was found to be flaky, and disabled in https://github.com/jax-ml/jax/pull/25443. This change re-enables that test with a slightly relaxed tolerance instead.

PiperOrigin-RevId: 706771186
2024-12-16 11:14:23 -08:00
Nitin Srinivasan
ecc2673e7b Disable failing test cases when JAX_ENABLE_X64=1 in the Bazel CPU build
PiperOrigin-RevId: 705635799
2024-12-12 14:41:52 -08:00
Jake VanderPlas
29a8cce66c jax.numpy: require boolean dtype for where argument 2024-12-05 09:27:19 -08:00
Peter Hawkins
f95417006f [tpu] Disable a cumulative reduction test on TPU v6e that currently hits an unimplemented case in XLA.
PiperOrigin-RevId: 692979420
2024-11-04 08:39:46 -08:00
Peter Hawkins
a8f44c4700 Fix a CI failure under NumPy 2.1.
PiperOrigin-RevId: 691428702
2024-10-30 08:30:25 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
2c221f2d5a Register several jax.numpy argument name deprecations 2024-08-22 09:41:53 -07:00
Peter Hawkins
323e257f67 Fix test failures.
PiperOrigin-RevId: 662703221
2024-08-13 17:02:14 -07:00
Dan Foreman-Mackey
9e9acc9ecc Fix compatibility with nightly numpy
Numpy recently merged support for the 2023.12 revision of the Array API:
https://github.com/numpy/numpy/pull/26724

This breaks two of our tests:

1. The first breakage was caused by differences in how numpy and JAX
   cast negative floats to `uint8`. Specifically
   `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
   `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
   We don't make any promises about consistency with casting floats to
   ints, noting that this can even be backend dependent. To fix our
   test, we now only generate positive inputs when the output dtype is
   unsigned.

2. The second failure was caused by the fact that the approach we took
   in #20550 to support backwards compatibility and the Array API for
   `clip` differs from the one used in numpy/numpy#26724. Again, the
   behavior is consistent, but it produces a different signature. I've
   skipped checking `clip`'s signature, but we should revisit it once
   the `a_min` and `a_max` parameters have been removed from JAX.

Fixes #22251
2024-07-03 11:07:58 -04:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
jax authors
bab7f40dec Merge pull request #21262 from vfdev-5:depr-change-ddof-to-correction-21088
PiperOrigin-RevId: 636949170
2024-05-24 09:47:27 -07:00
vfdev-5
55f8284e27 Added correction arg in jnp.var and jnp.std
Description:
- Added correction arg in jnp.var and jnp.std
- Addresses https://github.com/google/jax/issues/21088
- Updated signatures in init.pyi
- Updated tests
2024-05-24 16:16:12 +00:00
vfdev-5
3c201e0b8c jnp.var returns nan if N-ddof <= 0
Description:
- Updated jnp.var function to explicitly return np.nan if normalizer is non-positive
- Added a test for jnp.var and jnp.std

Fixed #21330
2024-05-24 13:45:59 +00:00
Meekail Zain
34c5163fd2 Refactored common upcast for integral-type accumulators 2024-05-06 15:13:10 +00:00
Meekail Zain
30cd3b88fd Add support for copy kwarg in astype to match Array API 2024-04-22 16:25:37 +00:00
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Hongyu Chiu
c295b1655c Fix axis=None and keepdims=True for jnp.quantile and jnp.median
Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median`

Remove `print`

Update tests
2024-03-15 09:05:02 +08:00
Jake VanderPlas
9b46e2d6a3 Support float8 in reduce_min & reduce_max 2023-12-18 13:37:45 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Jake VanderPlas
b18ca05bc6 jnp.mean: for f16 inputs, accumulate in f32 2023-09-27 18:27:19 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Jake VanderPlas
62fb0cd8a2 explicitly convert jnp.var scalar normalizer to float (from int)
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.

Fixes #15068

Co-authored by: Matthew Johnson <mattjj@google.com>
2023-05-23 09:44:08 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Jake VanderPlas
6f8885a0c2 lax_numpy: move quantile-based functions to reductions.py 2023-03-23 16:39:20 -07:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Jake VanderPlas
b037feb105 [x64] more type safety for lax_numpy-related tests 2022-12-01 11:18:02 -08:00
Peter Hawkins
99e1c3dd66 [JAX] Opt into high precision matrix multiplications in JAX tests that fail on A100.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.

Fixes https://github.com/google/jax/issues/12008

PiperOrigin-RevId: 488749199
2022-11-15 13:50:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Jake VanderPlas
32ef3ba37b jnp.average: support tuple axis 2022-10-06 10:20:46 -07:00
Jake VanderPlas
439217644a Split parts of lax_numpy_test.py into separate test files.
Why? The main test file is getting too big and this hinders iteration on individual tests

PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00