43 Commits

Author SHA1 Message Date
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
3ca7d67e8d Fully implement and test axes argument to jax.scipy.signal.fftconvolve
PiperOrigin-RevId: 523707411
2023-04-12 08:31:30 -07:00
Jake VanderPlas
cc7fc2e0af fftconvolve: adjust test tolerances
PiperOrigin-RevId: 523695809
2023-04-12 07:38:19 -07:00
Jake VanderPlas
d0ed619101 jax.scipy.signal.convolve: support method='fft' 2023-04-10 14:54:15 -07:00
jax authors
2ebb178c35 Merge pull request #15224 from jecampagne:fftconvolve2dr
PiperOrigin-RevId: 522671725
2023-04-07 13:29:50 -07:00
Jean-Eric Campagne
4beee13ba0 Add implementation of jax.scipy.fftconvolve 2023-04-07 17:19:08 +02:00
Peter Hawkins
eb80b17762 Increase precision of detrend test on TPU.
The test appears to pass at the higher tolerance these days.

PiperOrigin-RevId: 515474890
2023-03-09 16:32:40 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Jake VanderPlas
4714a5cc8f Add regression test for #12920 2022-10-21 12:52:32 -07:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Jake VanderPlas
03f2189f90 [x64] make jax.scipy.signal compatible with strict dtype promotion.
Also a fair bit of cleanup & refactoring of related code.
2022-06-21 09:28:46 -07:00
Rohit Santhanam
8d9f17df19 Disabled one and enabled several unit tests for ROCm. 2022-05-10 19:47:26 +00:00
Peter Hawkins
be9aac1dd3 Relax test tolerance for flaky test.
PiperOrigin-RevId: 445961557
2022-05-02 10:08:17 -07:00
Tianjian Lu
cdd1167095 [signal] Update signal detrend test.
PiperOrigin-RevId: 445253797
2022-04-28 14:47:07 -07:00
Yotaro Kubo
a7fd751acf Add istft to jax.scipy.signal. 2022-04-01 14:28:53 +09:00
Yotaro Kubo
2e70177385 Fix a bug in fft helper appears when nperseg=1. 2022-03-28 14:51:54 +09:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Yotaro Kubo
e085370ec4 Add some functions for spectral analysis.
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Jake VanderPlas
30ea76cb6b disable rank promotion for jax scipy tests 2021-08-04 10:44:23 -07:00
Luke Pfister
c33388b136 Support complex numbers in jax.scipy.signal.convolve/correlate 2021-06-18 13:07:00 -06:00
Peter Hawkins
48393c1def Relax some test tolerances of tests that fail on GPU in -x32 mode. 2021-02-16 22:14:43 -05:00
Jake VanderPlas
898fa7e6f8 cleanup: remove unused test arg 2020-11-10 09:16:44 -08:00
ayush-1506
a3c729b97a Fix #4775 + additional fixes 2020-11-09 10:40:14 +05:30
Jake Vanderplas
512ed18d5a
Cleanup: convert uses of 'import numpy as onp' in tests (#3756) 2020-07-14 13:03:24 -07:00
Jake Vanderplas
6b471e2ac6
Cleanup: define type lists in test_util & use in several test files. (#3616) 2020-07-07 17:01:38 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
33c455a1a8
Add jax.scipy.signal.detrend (#3516) 2020-06-22 19:49:00 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
Matthew Johnson
fcb6ff0a15 loosen scipy convolve test tolerance (GPU flaky) 2020-04-22 14:39:51 -07:00
Roy Frostig
d906c89e0b fix scipy_signal_test convolve failures 2020-04-16 13:25:52 -07:00
Peter Hawkins
1298e9e8c4
Fix some test failures. (#2713) 2020-04-14 18:23:19 -04:00
Jake VanderPlas
89c9c437f8 Add support for mode=same in convolve2d & correlate2d 2020-04-10 14:11:16 -07:00
Jake VanderPlas
edda69ef83 Add implementations of scipy.signal.convolve & correlate, 1d & 2d 2020-04-10 11:54:10 -07:00