339 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
44fbce56f9
Revert "Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)" (#3780)
This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d.

This is breaking some google-internal users of xla_computation. Reverting while I investigate.
2020-07-16 15:22:40 -07:00
Skye Wanderman-Milne
dbc3f83f6d
Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)
This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`.
2020-07-15 14:56:58 -07:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Neil Girdhar
503e5973ce Make vjp cotangent functions pytree-like
Fixes #3667
2020-07-10 23:22:38 -04:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
Matthew Johnson
107689e91f
improve vmap axis spec structure mismatch errors (#3619)
* improve vmap axis spec structure mismatch errors

fixes #3613

* deflake
2020-06-30 22:19:16 -07:00
Tom Hennigan
2ef39abf54
Retain original docstring when vmap'ing functions. (#3592) 2020-06-29 21:16:02 -07:00
Matthew Johnson
696958d2bd
add remat docstring (#3542)
* add remat docstring

first part of addressing #3314
2020-06-24 16:11:26 -07:00
Jony Hudson
677baa54dd
Clarify docstrings regarding usage of static arguments in jit and vmap. (#3484)
The docstring for pmap does not currently mention that any "non-data"
arguments need to be indicated in `static_broadcasted_argnums`. This
commit updates the docs to parallel those for `jax.jit` which does
explain this. Additionally, a remark is added to the `static_*` argument
descriptions on both jit and pmap, so that this point can be understood
without reading the whole docstring.
2020-06-24 15:47:09 -04:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Peter Hawkins
86fcfbfa1a
Fix memory leak when no axis is provided to pmap. (#3394)
* Fix memory leak when no axis is provided to pmap.

* Work around flake8 false positive.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-23 09:29:58 -04:00
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Skye Wanderman-Milne
8f4ba7e679
Allow specifying both devices and axis_size to pmap. (#3475)
This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.
2020-06-19 15:51:12 -07:00
Matthew Johnson
fe14aa3e00
Merge branch 'master' into changelist/316251368 2020-06-15 22:15:34 -07:00
Matthew Johnson
d4c6cb62ab print warning when doing jit-of-pmap 2020-06-15 21:37:30 -07:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
David Budden
3c78605bb8 Propagate raw __name__ and __doc__ of functions wrapped by jit and sharded_jit. 2020-06-15 10:32:09 +01:00
Peter Hawkins
04c9b32788
Small edit to documentation. (#3406) 2020-06-11 17:11:32 -04:00
Matthew Johnson
ee428008c4 yet another doc fix 2020-06-08 15:06:00 -07:00
George Necula
65d95f10ea
A couple of ad_util.zero were missed in #3222 (#3363) 2020-06-08 19:59:25 +03:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Peter Hawkins
96d72fca56
Edit documentation. (#3358)
Use :func:, :class: and :meth: when referring to Python objects.
Use :ref: for hyperlinks.
Fix some bad formatting.
2020-06-08 10:37:50 -04:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Adam Paszke
74d160f5e0
Don't keep primal arguments and results in the linearized jaxpr (#3233)
Linearized functions are supposed to take tangent types to tangent
types, and so all primal arguments are unused and primal results get
replaced by units.
2020-06-05 17:22:55 +02:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00
Matthew Johnson
b58eec51ac
make pmap axis checking an exception, hoist (#3239) 2020-06-02 20:28:59 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
Matthew Johnson
5c1de2836c
revise vmap in_axes/out_axes leaf type error msg (#3179)
from #3161 discussion
2020-05-21 08:00:18 -07:00
Matthew Johnson
eb81d7e7ff
add dict in_axes example to vmap docstring (#3176)
* add dict in_axes example to vmap docstring

fixes #3161

* fix typo
2020-05-21 06:47:02 -07:00
Matthew Johnson
ccb203c894
improve pmap unbound axis error, fixes #3120 (#3152) 2020-05-19 15:51:07 -07:00
Matthew Johnson
850f1afd95
improve errors for complex derivs, fixes #3121 (#3149) 2020-05-19 15:17:03 -07:00
Jamie Townsend
670fab59cf
Test code in docs and api.py docstrings (#2994)
Also remove jaxpr doc tests from api_test.py.
2020-05-16 16:19:24 +03:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Skye Wanderman-Milne
a5da921f4c
Move _flatten_axes to api_util.py (#3041)
This is in preparation for using it in sharded_jit.py (since sharded_jit isn't included in api.py yet).
2020-05-11 11:04:57 -07:00
George Necula
b3ae01d179
Use a new variable for static_broadcasted_argnums as a tuple. (#3027)
* Use a new variable for static_broadcasted_argnums as a tuple.

This works around a bug in pytype (b/156151503).
2020-05-10 13:16:16 +03:00
Matthew Johnson
2b622943f4
improve pmap static broadcasted kwarg error msg (#3018)
fixes #3007
2020-05-08 17:58:02 -07:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
d8b75e1913 Reimplemented the passing of tokens with a Jaxpr transform 2020-05-07 16:24:13 +03:00
George Necula
304009d772 Added error checking when starting compiled computations without starting
the outfeed receiver.
2020-05-07 16:24:13 +03:00
George Necula
931cb3f684 Ensure that we carry state only for control-flow conditionals that use print 2020-05-07 16:24:13 +03:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previosuly if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
Tom Hennigan
4c2c5ad5f4
Add a note about jax.pmap when leading dim is smaller than num devices. (#2949) 2020-05-04 15:46:12 -07:00
George Necula
d315564ebf
Fixed a few more places where device commitment was lost. (#2913)
* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment
  if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed the LaztTest.test_zeros_ones_compilation
  stated to fail because the `sticky` parameter to lazy_force_computation
  was changing. Fixed this by removing stickyness from the compilation key.
* Expanded docstring for jax.device_put; expanded the
  device placement FAQ entry.
2020-05-04 11:30:28 +03:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
2020-05-01 14:37:13 -07:00
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00