65 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
5ed2f4ef1c Remove checks for jaxlib v0.4.33 in tests 2024-10-11 15:39:24 -04:00
Dan Foreman-Mackey
ff1c2ac152 Add a test for 64-bit precision of IFFT on GPU.
Fixes https://github.com/jax-ml/jax/issues/23827. The precision fix was in https://github.com/openxla/xla/pull/17598, which has now been integrated into JAX, but we add a test here based on the repro from https://github.com/jax-ml/jax/issues/23827.

PiperOrigin-RevId: 680633622
2024-09-30 10:38:16 -07:00
Peter Hawkins
061f435b73 Bump test tolerance on FFT test that started failing in CI after an XLA change.
PiperOrigin-RevId: 679715691
2024-09-27 13:49:58 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
vfdev-5
bb1fb3ba45 Follow-up to #22736
On adding  device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
2024-07-30 05:39:19 +02:00
Meekail Zain
1c844aebca Updated 2024-02-05 18:01:48 -05:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
4e1b8fcdd2 Check dtypes in fft_p's abstract eval rule.
In particular, this catches a bad error when a bfloat16 is passed to rfft.
2023-10-06 08:04:01 -04:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
3bcba4ade9 Add shape checks for lax.fft.
Fixes https://github.com/google/jax/issues/4734
2022-08-13 16:32:52 +00:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Jake VanderPlas
297a2969a5 [x64] make fft functionality compatible with strict dtype promotion 2022-06-15 10:10:44 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Jake VanderPlas
5a96c0cb18 Skip test outside x64 2022-04-04 16:00:18 -07:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Yin Li
c5d4aba2a9 Fix fft dtype for norm='ortho' 2022-03-10 10:39:52 -05:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Peter Hawkins
84bccb2420 Support string fft_type values in lax.fft. 2022-02-03 08:52:38 -05:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
iollo jacopo
67dc16fc24 add fft normalisation 2021-10-20 22:15:35 +01:00
Peter Hawkins
efdc3cc794 [JAX] Fix more pylint errors.
* duplicate-string-formatting-argument: use f-strings.
* logging-format-interpolation: use interpolation. Some of these are real but minor performance problems.
* bad-string-format-type: don't use the wrong format type.

PiperOrigin-RevId: 400843759
2021-10-04 16:37:15 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Reza Rahimi
b44d35664c change skip_on_devices to handle device tags 2021-07-30 19:17:21 +00:00
Joost van Doorn
7091ae5af6 Add support for padding and cropping to fft 2021-04-17 08:38:24 +02:00
George Necula
5fce79ed44
Update fft_test.py
Trivial change to re-trigger the presubmit checks
2021-03-29 14:15:27 +03:00
Stephan Hoyer
bc2e42807e Fix transpose rule for jnp.fft.irfft
Fixes #6223
2021-03-25 18:49:38 -07:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Jake VanderPlas
1a83bb6f90 Cleanup: remove remaining instances of rng_factory boilerplate 2020-12-11 13:47:46 -08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Alex Dragan
412b9d5209
hfft and ihfft implementation (#3664) 2020-07-10 10:34:59 -07:00
Jake Vanderplas
0a6b715cd4
Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes #3689) (#3698) 2020-07-09 16:31:08 -07:00
Jake Vanderplas
19adce595c
Cleanup: use test_util dtypes where possible (#3695)
* Cleanup: use test_util dtypes where possible

* fix issue in fft test

* fix duplicate test name issue
2020-07-08 13:21:48 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Peter Hawkins
dc4761c72a
Fix type promotion for real FFTs. (#3300)
Only enable gradient test in x64 mode.
2020-06-02 17:04:52 -04:00
Peter Hawkins
a06b122e4a
Add support for 64-bit FFTs. (#3290) 2020-06-02 09:41:44 -04:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Jake Vanderplas
2ad425d9ed
Fix coverage of axis argument in fft_test (#3274) 2020-06-01 10:48:04 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00