39 Commits

Author SHA1 Message Date
Jake VanderPlas
b4f98eef7e refactor: move scalar type defs out of lax_numpy.py 2025-02-06 14:48:10 -08:00
Jake VanderPlas
23c1d62910 internal: move more NumPy APIs to ensure_arraylike 2025-01-23 08:48:13 -08:00
Dan Foreman-Mackey
c6131ee527 Add support for N-D FFTs with D>3. 2024-12-19 15:23:30 +00:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Jake VanderPlas
162322fc70 Better docs for fftshift & ifftshift 2024-10-04 06:12:02 -07:00
jax authors
aa334145b4 Merge pull request #22958 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 661340981
2024-08-09 11:30:16 -07:00
rajasekharporeddy
ff1f199d09 Improved docs for jnp.fft.rfftn and jnp.fft.irfftn 2024-08-09 23:07:17 +05:30
rajasekharporeddy
6ee1555d21 Fix broken links in jnp.fft.fftfreq and jnp.fft.rfftfreq 2024-08-09 14:23:56 +05:30
rajasekharporeddy
3095c570b8 Better docs for jnp.fft.rfft2 and jnp.fft.irfft2 2024-08-07 17:59:53 +05:30
rajasekharporeddy
1acff9c739 Better docs for jnp.fft.hfft and jnp.fft.ihfft 2024-08-05 21:53:29 +05:30
vfdev-5
bb1fb3ba45 Follow-up to #22736
On adding  device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
2024-07-30 05:39:19 +02:00
Jake VanderPlas
6516a079f8 [array API] add device argument to fftfreq & rfftfreq 2024-07-29 13:23:54 -07:00
rajasekharporeddy
1525f01270 Improved docs for jnp.fft.fft2 and ifft2 2024-07-23 23:51:34 +05:30
rajasekharporeddy
1650d1e8aa Improved docs for jnp.fft.rfft and irfft 2024-07-23 11:33:03 +05:30
rajasekharporeddy
885bedd33c Improved docs for jnp.fft.fft and jnp.fft.ifft 2024-07-17 07:28:52 +05:30
rajasekharporeddy
01ad65c7f1 Improved docs for jnp.fft.ifftn 2024-07-10 20:45:35 +05:30
rajasekharporeddy
6a65707bd2 Improved docs for jnp.fft.fftn 2024-07-09 12:18:23 +05:30
Meekail Zain
1c844aebca Updated 2024-02-05 18:01:48 -05:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
eukub
5579ff264c сhanged concatenation of strings to f-strings to improve readability and unify with the rest of code 2023-12-28 19:38:13 +03:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Jake VanderPlas
d25a96caea [x64] more type safety in jax.scipy.signal 2022-12-01 13:43:07 -08:00
Jake VanderPlas
2416d15435 Call _check_arraylike for jnp.linalg & jnp.fft functions 2022-10-31 09:19:53 -07:00
Jake VanderPlas
96461998d1 [typing] add annotations to numpy.fft 2022-10-04 15:52:54 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Marc van Zee
9d18f43a01 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.

PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
Jake VanderPlas
297a2969a5 [x64] make fft functionality compatible with strict dtype promotion 2022-06-15 10:10:44 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Yin Li
c5d4aba2a9 Fix fft dtype for norm='ortho' 2022-03-10 10:39:52 -05:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
iollo jacopo
67dc16fc24 add fft normalisation 2021-10-20 22:15:35 +01:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Stephan Hoyer
22943ef839 Add jit to lax.fft
The main motivation here is ensuring that FFTs are always marked in
profiler results, which is not necessarily the case where running on
TPUs.

I would jit decorate the user facing functions in jax.numpy.fft, but
these functions also accept parameters as lists, e.g., for axes, which
are mutable and hence not valid as direct input into jit decorated
functions. This might be worth doing, but would be a breaking change.
2021-08-30 09:28:35 -07:00
Joost van Doorn
7091ae5af6 Add support for padding and cropping to fft 2021-04-17 08:38:24 +02:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00