350 Commits

Author SHA1 Message Date
George Necula
9fe7a3c33b
Revert "Avoid re-flattening in jit() when no donate_argnums are present. (#3945)" (#3953)
This reverts commit 4e873f417ab4f3e68a102548167f2b0a005edfad.

See comments in #3945 about the failure.
2020-08-04 13:20:30 +03:00
Adrià Puigdomènech
4e873f417a
Avoid re-flattening in jit() when no donate_argnums are present. (#3945)
Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.
2020-08-03 15:13:28 -07:00
John Aslanides
8a8bb702d2
Catch invalid (negative) in/out axes in vmap. (#3926)
Catch invalid (negative) in/out axes in vmap.
2020-08-01 13:33:11 -07:00
Jean-Baptiste Lespiau
4853eb103c
Fix static_argnums in xla_computation. (#3924) 2020-07-31 15:15:51 -07:00
Matthew Johnson
843d710116
allow mask to return output logical shape (#3929)
When the `mask` argument `out_shape` is not provided, or when it has
value `None`, return the output logical shape to the user.
2020-07-31 15:11:01 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
616d63b19a
fix vmap error, fixes #3877 (#3879) 2020-07-27 21:51:12 -07:00
Matthew Johnson
67ad5eb5a1
add result_shape option to xla_computation (#3844)
add result_shape option to xla_computation
2020-07-23 19:38:56 -07:00
Roy Frostig
fa2a0275c8 revert #3674 2020-07-17 15:44:51 -07:00
Skye Wanderman-Milne
2df486ee83
Note in pmap docs that pmap compiles like jit. (#3787) 2020-07-17 15:11:26 -07:00
Roy Frostig
6416ca0e9d append filtered stack traces to error messages raised under transformations 2020-07-16 17:12:09 -07:00
Skye Wanderman-Milne
44fbce56f9
Revert "Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)" (#3780)
This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d.

This is breaking some google-internal users of xla_computation. Reverting while I investigate.
2020-07-16 15:22:40 -07:00
Skye Wanderman-Milne
dbc3f83f6d
Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)
This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`.
2020-07-15 14:56:58 -07:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Neil Girdhar
503e5973ce Make vjp cotangent functions pytree-like
Fixes #3667
2020-07-10 23:22:38 -04:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
Matthew Johnson
107689e91f
improve vmap axis spec structure mismatch errors (#3619)
* improve vmap axis spec structure mismatch errors

fixes #3613

* deflake
2020-06-30 22:19:16 -07:00
Tom Hennigan
2ef39abf54
Retain original docstring when vmap'ing functions. (#3592) 2020-06-29 21:16:02 -07:00
Matthew Johnson
696958d2bd
add remat docstring (#3542)
* add remat docstring

first part of addressing #3314
2020-06-24 16:11:26 -07:00
Jony Hudson
677baa54dd
Clarify docstrings regarding usage of static arguments in jit and vmap. (#3484)
The docstring for pmap does not currently mention that any "non-data"
arguments need to be indicated in `static_broadcasted_argnums`. This
commit updates the docs to parallel those for `jax.jit` which does
explain this. Additionally, a remark is added to the `static_*` argument
descriptions on both jit and pmap, so that this point can be understood
without reading the whole docstring.
2020-06-24 15:47:09 -04:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Peter Hawkins
86fcfbfa1a
Fix memory leak when no axis is provided to pmap. (#3394)
* Fix memory leak when no axis is provided to pmap.

* Work around flake8 false positive.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-23 09:29:58 -04:00
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Skye Wanderman-Milne
8f4ba7e679
Allow specifying both devices and axis_size to pmap. (#3475)
This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.
2020-06-19 15:51:12 -07:00
Matthew Johnson
fe14aa3e00
Merge branch 'master' into changelist/316251368 2020-06-15 22:15:34 -07:00
Matthew Johnson
d4c6cb62ab print warning when doing jit-of-pmap 2020-06-15 21:37:30 -07:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
David Budden
3c78605bb8 Propagate raw __name__ and __doc__ of functions wrapped by jit and sharded_jit. 2020-06-15 10:32:09 +01:00
Peter Hawkins
04c9b32788
Small edit to documentation. (#3406) 2020-06-11 17:11:32 -04:00
Matthew Johnson
ee428008c4 yet another doc fix 2020-06-08 15:06:00 -07:00
George Necula
65d95f10ea
A couple of ad_util.zero were missed in #3222 (#3363) 2020-06-08 19:59:25 +03:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Peter Hawkins
96d72fca56
Edit documentation. (#3358)
Use :func:, :class: and :meth: when referring to Python objects.
Use :ref: for hyperlinks.
Fix some bad formatting.
2020-06-08 10:37:50 -04:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Adam Paszke
74d160f5e0
Don't keep primal arguments and results in the linearized jaxpr (#3233)
Linearized functions are supposed to take tangent types to tangent
types, and so all primal arguments are unused and primal results get
replaced by units.
2020-06-05 17:22:55 +02:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00
Matthew Johnson
b58eec51ac
make pmap axis checking an exception, hoist (#3239) 2020-06-02 20:28:59 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
Matthew Johnson
5c1de2836c
revise vmap in_axes/out_axes leaf type error msg (#3179)
from #3161 discussion
2020-05-21 08:00:18 -07:00
Matthew Johnson
eb81d7e7ff
add dict in_axes example to vmap docstring (#3176)
* add dict in_axes example to vmap docstring

fixes #3161

* fix typo
2020-05-21 06:47:02 -07:00
Matthew Johnson
ccb203c894
improve pmap unbound axis error, fixes #3120 (#3152) 2020-05-19 15:51:07 -07:00
Matthew Johnson
850f1afd95
improve errors for complex derivs, fixes #3121 (#3149) 2020-05-19 15:17:03 -07:00
Jamie Townsend
670fab59cf
Test code in docs and api.py docstrings (#2994)
Also remove jaxpr doc tests from api_test.py.
2020-05-16 16:19:24 +03:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Skye Wanderman-Milne
a5da921f4c
Move _flatten_axes to api_util.py (#3041)
This is in preparation for using it in sharded_jit.py (since sharded_jit isn't included in api.py yet).
2020-05-11 11:04:57 -07:00
George Necula
b3ae01d179
Use a new variable for static_broadcasted_argnums as a tuple. (#3027)
* Use a new variable for static_broadcasted_argnums as a tuple.

This works around a bug in pytype (b/156151503).
2020-05-10 13:16:16 +03:00
Matthew Johnson
2b622943f4
improve pmap static broadcasted kwarg error msg (#3018)
fixes #3007
2020-05-08 17:58:02 -07:00