1816 Commits

Author SHA1 Message Date
jax authors
4b25b77718 Merge pull request #4766 from bchetioui:cleanup_fft
PiperOrigin-RevId: 340549622
2020-11-03 16:37:46 -08:00
Jean-Baptiste Lespiau
3e5a0ff0c4 Add methods to interact with DeviceArray objects.
We are going to add a C++ implementation, this is a useful refectoring to ease the transition. In short,

- `isinstance(x, DeviceArray)` will continue to work
- type(x) is DeviceArray will be replaced by type_is_device_array(x)
- DeviceArray(...) constructor will be replaced by get_device_array.
2020-11-03 22:16:28 +01:00
Jean-Baptiste Lespiau
6b59a2057b Fix some indentation. 2020-11-03 16:39:32 +01:00
Benjamin Chetioui
ad63d8d6a9 Cleanup outdated jaxlib TODOs in jax/_src/lax/fft.py 2020-11-03 16:01:44 +01:00
Julius Kunze
d29b69a3f8 Support polynomial division for mask
Add support for multivariate polynomial division on polymorphic sizes without remainder, allowing `mask` of

- `jnp.reshape` with -1 size and
- `lax.slice` for polymorphic stride for sizes  `poly * stride`, i. e. `(n^2+2n)//n = n+2`

Also clean up `Poly` class, improve error messages.
2020-11-02 21:26:22 +01:00
jax authors
d158647c83 Merge pull request #4706 from apaszke:vmap-collectives-in-scan
PiperOrigin-RevId: 339646941
2020-10-29 05:11:23 -07:00
Matthew Johnson
7a73e99e14 simplify select jvp
also remove some coverage of broadcast_p, which jax never generates now
2020-10-26 17:45:40 -07:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Matthew Johnson
9ba28d2634 Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338670328
2020-10-23 07:35:01 -07:00
Matthew Johnson
fcaced32aa Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338606348
2020-10-22 21:18:22 -07:00
jax authors
8ff6396159 Merge pull request #4686 from google:lazy-simplification-again
PiperOrigin-RevId: 338582074
2020-10-22 17:29:55 -07:00
Matthew Johnson
ced333d1d4 redo #4535 lazy simplification 2020-10-22 16:56:29 -07:00
Matthew Johnson
f40ac06717 make lax.dynamic_slice transpose handle symb zeros 2020-10-22 15:31:43 -07:00
Roman Novak
da0bff2fa8 Add lax.conv_general_dilated_patches 2020-10-20 22:58:53 -07:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00