22 Commits

Author SHA1 Message Date
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00
Matthew Johnson
979b38352f make vmap structured axes work for any pytree 2019-10-31 14:09:12 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
c53c8bbb43 Some progress de-tupling ad.py 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Peter Hawkins
476dc3db64 Python changes in preparation for adding a C++ implementation of the PyTree utilities. 2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992 speed up pmap axis-size getting
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
b6031ffdd7 avoid packing leaf outputs for jit/pmap funs 2019-05-17 07:36:52 -07:00
Matthew Johnson
15a4554ffb flatten out pytrees in jit at the api.py level 2019-05-03 11:39:37 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
902c149c47 add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:

```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x)  # relu
```

The partial evaluation logic works with tuples, so this works too:

```python
lax.cond(x < 0,
         x, lambda x: (x, x, 1, 1, 1),
         x, lambda x: (x, 1, x, 1, 2))
```

in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.

For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
2019-03-02 17:37:38 -08:00
Peter Hawkins
3e25d290be Set __wrapped__ attribute instead of using functools.wraps to fix Python 2.7 problem. 2019-02-14 11:00:40 -05:00
Peter Hawkins
33cd3d0299 Use functools.wraps as the basis for api_util.wraps.
Fixes API signatures in `jax.random` documentation (https://github.com/google/jax/issues/370).
2019-02-14 10:07:47 -05:00
Matthew Johnson
da2d185444 tweak 2019-01-28 09:19:06 -08:00
Matthew Johnson
945fa34e7e tweaks 2019-01-28 09:00:02 -08:00
Matthew Johnson
780106f892 moving pxla flattening/chunking to api.py, wip 2019-01-28 08:38:14 -08:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
Matthew Johnson
ad4322c5da playing around with flattening functions 2019-01-06 12:49:35 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00