1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-25 20:06:05 +00:00

25 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Peter Hawkins
3bcba4ade9 Add shape checks for lax.fft.
Fixes https://github.com/google/jax/issues/4734
2022-08-13 16:32:52 +00:00
Xin Zhou
78f9b9247b [mhlo] Add result type inference for mhlo.fft.
PiperOrigin-RevId: 466081681
2022-08-08 10:32:07 -07:00
Jake VanderPlas
297a2969a5 [x64] make fft functionality compatible with strict dtype promotion 2022-06-15 10:10:44 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
4806c29bf7 [MHLO] Add MHLO lowerings for FFT ops.
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Peter Hawkins
84bccb2420 Support string fft_type values in lax.fft. 2022-02-03 08:52:38 -05:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Peter Hawkins
e783cbcb72 Port remaining translation rules inside JAX to new style.
PiperOrigin-RevId: 404288551
2021-10-19 09:48:37 -07:00
Peter Hawkins
8c3b212dd6 Improve real type conversion in a couple more places. 2021-10-18 13:50:11 -04:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Stephan Hoyer
22943ef839 Add jit to lax.fft
The main motivation here is ensuring that FFTs are always marked in
profiler results, which is not necessarily the case where running on
TPUs.

I would jit decorate the user facing functions in jax.numpy.fft, but
these functions also accept parameters as lists, e.g., for axes, which
are mutable and hence not valid as direct input into jit decorated
functions. This might be worth doing, but would be a breaking change.
2021-08-30 09:28:35 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Stephan Hoyer
bc2e42807e Fix transpose rule for jnp.fft.irfft
Fixes 
2021-03-25 18:49:38 -07:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
Benjamin Chetioui
ad63d8d6a9 Cleanup outdated jaxlib TODOs in jax/_src/lax/fft.py 2020-11-03 16:01:44 +01:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in  on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00