Qiumin Xu
31600aac62
Add named_call public API.
...
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
jax authors
3e7285b436
Merge pull request #4301 from JuliusKunze:mask-polynomial-division
...
PiperOrigin-RevId: 341256792
2020-11-07 23:52:26 -08:00
Peter Hawkins
79cf7d11ab
Fix max/min confusion in lax.py.
2020-11-05 09:24:18 -05:00
Qiumin Xu
70f03eb63e
Upstream named_call to jax
...
Upstream the implementation of named_call to JAX. (there are equivalent implementations in Haiku and Flax)
Reference:
Flax implementation:
https://github.com/google/flax/blob/master/flax/core/named_call.py
Haiku implementation:
https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/named_call.py
2020-11-04 13:26:28 -08:00
Jean-Baptiste Lespiau
3e5a0ff0c4
Add methods to interact with DeviceArray
objects.
...
We are going to add a C++ implementation, this is a useful refectoring to ease the transition. In short,
- `isinstance(x, DeviceArray)` will continue to work
- type(x) is DeviceArray will be replaced by type_is_device_array(x)
- DeviceArray(...) constructor will be replaced by get_device_array.
2020-11-03 22:16:28 +01:00
Jean-Baptiste Lespiau
6b59a2057b
Fix some indentation.
2020-11-03 16:39:32 +01:00
Julius Kunze
d29b69a3f8
Support polynomial division for mask
...
Add support for multivariate polynomial division on polymorphic sizes without remainder, allowing `mask` of
- `jnp.reshape` with -1 size and
- `lax.slice` for polymorphic stride for sizes `poly * stride`, i. e. `(n^2+2n)//n = n+2`
Also clean up `Poly` class, improve error messages.
2020-11-02 21:26:22 +01:00
Matthew Johnson
7a73e99e14
simplify select jvp
...
also remove some coverage of broadcast_p, which jax never generates now
2020-10-26 17:45:40 -07:00
Matthew Johnson
9ba28d2634
Copybara import of the project:
...
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:
redo #4535 lazy simplification
PiperOrigin-RevId: 338670328
2020-10-23 07:35:01 -07:00
Matthew Johnson
fcaced32aa
Copybara import of the project:
...
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:
redo #4535 lazy simplification
PiperOrigin-RevId: 338606348
2020-10-22 21:18:22 -07:00
jax authors
8ff6396159
Merge pull request #4686 from google:lazy-simplification-again
...
PiperOrigin-RevId: 338582074
2020-10-22 17:29:55 -07:00
Matthew Johnson
ced333d1d4
redo #4535 lazy simplification
2020-10-22 16:56:29 -07:00
Matthew Johnson
f40ac06717
make lax.dynamic_slice transpose handle symb zeros
2020-10-22 15:31:43 -07:00
Roman Novak
da0bff2fa8
Add lax.conv_general_dilated_patches
2020-10-20 22:58:53 -07:00
Peter Hawkins
10b7d7d7c2
Move implementation of jax.lax into jax._src.lax.
...
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00