29 Commits

Author SHA1 Message Date
Matthew Johnson
e88579f22b fix typo 2020-09-18 19:41:59 -07:00
Matthew Johnson
f172fb74e1 plumb donate_argnums into jax.xla_computation 2020-09-18 17:39:05 -07:00
Matthew Johnson
107689e91f
improve vmap axis spec structure mismatch errors (#3619)
* improve vmap axis spec structure mismatch errors

fixes #3613

* deflake
2020-06-30 22:19:16 -07:00
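The errors in question arise when the `in_axes` spec doesn't match the pytree structure of the arguments. A minimal sketch of the pattern (hypothetical example, not taken from the commit):

```python
import jax
import jax.numpy as jnp

# in_axes must mirror the pytree structure of the arguments: the
# single positional argument here is a dict, so its axis spec is a
# dict with the same keys (wrapped in a tuple, one entry per arg)
f = jax.vmap(lambda d: d['a'] + d['b'], in_axes=({'a': 0, 'b': 0},))
out = f({'a': jnp.arange(3.0), 'b': jnp.ones(3)})
# a spec whose structure doesn't match the argument (e.g. missing a
# key) raises the structure-mismatch error this commit improves
```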
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problems:

  1. Users at the limit of available device memory are constrained by the
     additional copy of their parameters and other state, even though they
     typically only require one copy. Donating these buffers typically frees
     100M+ of device memory and is a critical optimization for larger models
     to match state-of-the-art performance in other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function, and it will instruct XLA that it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated, they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follow those of `static_argnums`: it
identifies positional arguments to the computation whose buffers are to be
donated and reused as part of the output.
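A hedged sketch of the call pattern (note that on some backends, such as CPU, donation may be ignored with a warning, though the call semantics are unchanged):

```python
import jax
import jax.numpy as jnp

# donate the buffer of the first positional argument; the output may
# alias it, and the donated input must not be reused afterwards
step = jax.jit(lambda x: x ** 2, donate_argnums=0)

x = jnp.float32(3.0)
x = step(x)  # the old x's buffer is donated to hold the result
```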

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough that JAX doesn't just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety, since it is still possible to
reuse a key as part of a traced computation; however, it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Skye Wanderman-Milne
a5da921f4c
Move _flatten_axes to api_util.py (#3041)
This is in preparation for using it in sharded_jit.py (since sharded_jit isn't included in api.py yet).
2020-05-11 11:04:57 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
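A hedged sketch of the `custom_jvp` API this commit introduces, following the shape of the current jax documentation (the function names are illustrative):

```python
import jax

@jax.custom_jvp
def square(x):
    return x * x

@square.defjvp
def square_jvp(primals, tangents):
    # custom forward-mode rule: d(x^2) = 2x dx
    (x,), (t,) = primals, tangents
    return square(x), 2.0 * x * t

# reverse mode is derived by transposing the (linear) JVP rule
grad_val = jax.grad(square)(3.0)
```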
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent ways of importing the module:
prefer importing with the qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00
Matthew Johnson
979b38352f make vmap structured axes work for any pytree 2019-10-31 14:09:12 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
c53c8bbb43 Some progress de-tupling ad.py 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Peter Hawkins
476dc3db64 Python changes in preparation for adding a C++ implementation of the PyTree utilities. 2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992 speed up pmap axis-size getting
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
b6031ffdd7 avoid packing leaf outputs for jit/pmap funs 2019-05-17 07:36:52 -07:00
Matthew Johnson
15a4554ffb flatten out pytrees in jit at the api.py level 2019-05-03 11:39:37 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
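A minimal sketch of the behavior this commit targets (assuming a current jax install): keyword arguments to a jit-compiled function are traced like positional ones, not baked in as static constants:

```python
import jax

@jax.jit
def scale(x, factor=2.0):
    # `factor` is a traced value, so passing a different value does
    # not require marking it static
    return x * factor

y = scale(3.0, factor=4.0)
```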
Matthew Johnson
902c149c47 add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:

```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x)  # relu
```

The partial evaluation logic works with tuples, so this works too:

```python
lax.cond(x < 0,
         x, lambda x: (x, x, 1, 1, 1),
         x, lambda x: (x, 1, x, 1, 2))
```

in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.

For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
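For reference, a sketch of the relu example under the newer `lax.cond` calling convention (`cond(pred, true_fun, false_fun, operand)` rather than the five-argument form above):

```python
import jax
from jax import lax

@jax.jit
def relu(x):
    # both branches must return matching shapes and dtypes,
    # hence x * 0 rather than a bare 0.0 literal
    return lax.cond(x < 0, lambda x: x * 0, lambda x: x, x)
```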
2019-03-02 17:37:38 -08:00
Peter Hawkins
3e25d290be Set __wrapped__ attribute instead of using functools.wraps to fix Python 2.7 problem. 2019-02-14 11:00:40 -05:00
Peter Hawkins
33cd3d0299 Use functools.wraps as the basis for api_util.wraps.
Fixes API signatures in `jax.random` documentation (https://github.com/google/jax/issues/370).
2019-02-14 10:07:47 -05:00
Matthew Johnson
da2d185444 tweak 2019-01-28 09:19:06 -08:00
Matthew Johnson
945fa34e7e tweaks 2019-01-28 09:00:02 -08:00
Matthew Johnson
780106f892 moving pxla flattening/chunking to api.py, wip 2019-01-28 08:38:14 -08:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
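A hedged sketch of what pytree support means here: Jacobians of functions over dicts (or other pytrees) come back with matching structure:

```python
import jax

def f(params):
    return params['w'] * params['w']

# the result is a dict mirroring the structure of the input pytree
grads = jax.jacfwd(f)({'w': 3.0})
```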
Matthew Johnson
ad4322c5da playing around with flattening functions 2019-01-06 12:49:35 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00