715 Commits

Author SHA1 Message Date
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
jax authors
3e7285b436 Merge pull request #4301 from JuliusKunze:mask-polynomial-division
PiperOrigin-RevId: 341256792
2020-11-07 23:52:26 -08:00
Peter Hawkins
79cf7d11ab Fix max/min confusion in lax.py. 2020-11-05 09:24:18 -05:00
Qiumin Xu
70f03eb63e Upstream named_call to jax
Upstream the implementation of named_call to JAX. (there are equivalent implementations in Haiku and Flax)

Reference:
Flax implementation:
https://github.com/google/flax/blob/master/flax/core/named_call.py

Haiku implementation:
https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/named_call.py
2020-11-04 13:26:28 -08:00
Jean-Baptiste Lespiau
3e5a0ff0c4 Add methods to interact with DeviceArray objects.
We are going to add a C++ implementation, this is a useful refectoring to ease the transition. In short,

- `isinstance(x, DeviceArray)` will continue to work
- type(x) is DeviceArray will be replaced by type_is_device_array(x)
- DeviceArray(...) constructor will be replaced by get_device_array.
2020-11-03 22:16:28 +01:00
Jean-Baptiste Lespiau
6b59a2057b Fix some indentation. 2020-11-03 16:39:32 +01:00
Julius Kunze
d29b69a3f8 Support polynomial division for mask
Add support for multivariate polynomial division on polymorphic sizes without remainder, allowing `mask` of

- `jnp.reshape` with -1 size and
- `lax.slice` for polymorphic stride for sizes  `poly * stride`, i. e. `(n^2+2n)//n = n+2`

Also clean up `Poly` class, improve error messages.
2020-11-02 21:26:22 +01:00
Matthew Johnson
7a73e99e14 simplify select jvp
also remove some coverage of broadcast_p, which jax never generates now
2020-10-26 17:45:40 -07:00
Matthew Johnson
9ba28d2634 Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338670328
2020-10-23 07:35:01 -07:00
Matthew Johnson
fcaced32aa Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338606348
2020-10-22 21:18:22 -07:00
jax authors
8ff6396159 Merge pull request #4686 from google:lazy-simplification-again
PiperOrigin-RevId: 338582074
2020-10-22 17:29:55 -07:00
Matthew Johnson
ced333d1d4 redo #4535 lazy simplification 2020-10-22 16:56:29 -07:00
Matthew Johnson
f40ac06717 make lax.dynamic_slice transpose handle symb zeros 2020-10-22 15:31:43 -07:00
Roman Novak
da0bff2fa8 Add lax.conv_general_dilated_patches 2020-10-20 22:58:53 -07:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00