19 Commits

Author SHA1 Message Date
Jake Vanderplas
05904faf0f
Change onp/np to np/jnp in docs & notebooks (#3760) 2020-07-15 13:17:38 -07:00
Roy Frostig
8a62a9b654
block-unrolled scan primitive implementation (#3738)
* block-unrolled scan implementation, via optional `_unroll` scan parameter

* index statically in the inlined path of lax.scan

* make `unroll` a required scan parameter, and test that it unrolls
2020-07-15 14:00:50 -04:00
8bitmp3
242b382bab
Remove a deprecated reference to testExamplesJaxprDoc in Understanding Jaxpr (#3680) 2020-07-07 11:29:44 -07:00
igorwilbert
e5d4ca31a8
Fix typo understanding jaxprs page on readthedocs (#3513) 2020-06-22 12:31:08 -07:00
Roy Frostig
15bc62204e jaxpr: support dropped assignment 2020-06-09 13:47:17 -07:00
Roy Frostig
bd3cab9768 update jaxpr doc to reflect lax.switch and indexed cond 2020-06-03 22:19:15 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Roy Frostig
916953ace8 update example in jaxpr doc 2020-05-29 16:57:40 -07:00
Roy Frostig
de03c99b52 update jaxpr doc and tests with single-operand cond 2020-05-13 21:14:41 -07:00
James Bradbury
f60184e12e
Support axis_index_groups in allreduce collectives (#2382)
* support replica groups in allreduce collectives

* add test and fix jaxpr in docs

* switch from XLA replica IDs to JAX axis indices

* fix psum transpose rule

* test other nesting order + imperfect nesting

* update jaxpr.rst

* handle None case

* add note+check that groups  cover the index space

* switch split_axis assert to NotImplementedError

* update CHANGELOG
2020-05-08 14:00:34 -07:00
Matthew Johnson
3cd409ee88
add optional 'forward' argument to lax.scan (#2921)
* add optional 'forward' argument to lax.scan

* switch to reverse; revise disable-jit case

* fix jaxpr.rst

* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
2020-05-04 19:44:22 -07:00
Roman Ring
525235d8c9
Fix a codeblock in the "understanding jaxpr" doc. (#2942)
This fixes an issue where the codeblock didn't render properly on the website.
2020-05-04 13:20:21 +03:00
Jamie Townsend
283393f773
Update jaxpr.rst (#2859)
* Update jaxpr doc

* Make jaxpr.rst doctestable
2020-04-27 16:44:46 -07:00
Skye Wanderman-Milne
f37f235183
Fix up previous jaxpr.rst commit. (#2647) 2020-04-08 11:29:02 -07:00
Skye Wanderman-Milne
f8dc650b2a
Update scan jaxpr documentation. (#2641)
Closes #2640.
2020-04-07 19:03:41 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
370558def3 Removed a couple of slow notebooks from RTD auto-rendering.
Trying to address the timeouts in RTD rendering.

Also fixed bad itemized list in autodiff cookbook, and a few minor warnings:
Issue: #2092
2020-02-15 11:43:10 +01:00
George Necula
20dbc62277 Updated docstrings based on review comments 2020-02-13 09:28:01 +01:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00