362 Commits

Author SHA1 Message Date
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
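One way to read this cleanup: a shape is a plain tuple of Python ints, so its element count can be computed without going through an array library. The sketch below uses the standard library's `math.prod` (Python 3.8+) as a stand-in. The helper name `shape_size` is hypothetical and is not claimed to be what the commit introduced.

```python
import math


def shape_size(shape):
    """Number of elements in an array with the given shape (a tuple of ints)."""
    # math.prod works on plain Python integers, so the result is an exact
    # int rather than an array or a fixed-width scalar. The empty shape ()
    # gives 1, which is the correct size for a scalar.
    return math.prod(shape)


print(shape_size((2, 3, 4)))  # 24
print(shape_size(()))         # 1
```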
Roy Frostig
dbca9e682c unrevert #3674 (revert #3791) 2020-08-17 18:13:58 -07:00
Matthew Johnson
9120701188
allow xla_computation to psum a constant (#4078)
* allow xla_computation to psum a constant

* allow axis_env to be None
2020-08-16 20:00:40 -07:00
Matthew Johnson
8232f2deee
adapt _TempAxisName for unhashable objs (#4077)
2020-08-15 22:55:18 -07:00
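To illustrate the idea in the commit title above, here is a minimal sketch of a wrapper that gives any object, hashable or not, an identity-based hash so that it can be used as a dictionary key. The class name `TempName` and its fields are hypothetical; this is not the actual `_TempAxisName` code.

```python
class TempName:
    """Hashable stand-in for an arbitrary (possibly unhashable) object."""

    def __init__(self, obj):
        # Keep a reference so the object stays alive and its id() is not reused.
        self.obj = obj

    def __hash__(self):
        # Hash by identity, which works even for lists and dicts.
        return hash(id(self.obj))

    def __eq__(self, other):
        return isinstance(other, TempName) and self.obj is other.obj

    def __repr__(self):
        return f"<temp name for {type(self.obj).__name__} at {id(self.obj):#x}>"


unhashable = [1, 2, 3]            # a list cannot be a dict key directly
env = {TempName(unhashable): 8}   # but the wrapper can
print(env[TempName(unhashable)])  # 8
```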
Ryan Sepassi
394a33c828
Add in_parts and out_parts optional arguments jax.xla_computation. (#4055)
PR #3771 redux (reverted in #3780)

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-08-14 13:05:58 -07:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules, which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00
Roy Frostig
34f90f55e6 remove auto-parallelization transformation 2020-08-12 13:44:38 -07:00
Jamie Townsend
1e07712955
Fix typos in api.py docstrings (#4021) 2020-08-11 07:09:36 -07:00
Adam Paszke
cd54dd9778
Implement the invertible decorator in terms of custom_vjp (#3957)
This simplifies the implementation significantly, as we can piggyback
off of all the logic for custom derivatives. For example, the previous
implementation didn't support differentiating with respect to a subset
of function parameters, but the new one does.
2020-08-11 11:45:58 +02:00
Matthew Johnson
6a3b920507
make make_jaxpr work on tracer example args (#4014)
(don't use xla.abstractify)
2020-08-10 18:11:57 -07:00
John Aslanides
038c85dad0
Improve type annotations for jit and vmap. (#3938) 2020-08-08 12:22:54 -04:00
Adrià Puigdomènech
d4d7323a57
Avoid re-flattening in jit() when no donate_argnums are present. (#3955)
Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.
2020-08-04 16:45:03 +03:00
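The fast path described in the commit above can be sketched roughly as follows: flatten the arguments once, and only do the extra per-argument flattening when some arguments are actually donated. The helpers `flatten` and `prepare_call` are hypothetical and operate on nested tuples and lists only; they are an illustration of the special-casing, not the code in `jit`.

```python
def flatten(tree):
    """Flatten nested tuples/lists into a flat list of leaves."""
    if isinstance(tree, (tuple, list)):
        return [leaf for sub in tree for leaf in flatten(sub)]
    return [tree]


def prepare_call(args, donate_argnums=()):
    """Return the flat leaves plus a per-leaf 'donated' mask."""
    flat = flatten(args)
    if not donate_argnums:
        # Fast path: nothing is donated, so one flattening pass is enough.
        return flat, ()
    # Slow path: flatten each argument again to learn how many leaves it
    # contributes, so the mask can mark every leaf of a donated argument.
    donated = []
    for i, arg in enumerate(args):
        donated.extend([i in donate_argnums] * len(flatten(arg)))
    return flat, tuple(donated)


print(prepare_call((1, (2, 3))))                        # ([1, 2, 3], ())
print(prepare_call((1, (2, 3)), donate_argnums=(1,)))   # ([1, 2, 3], (False, True, True))
```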
George Necula
9fe7a3c33b
Revert "Avoid re-flattening in jit() when no donate_argnums are present. (#3945)" (#3953)
This reverts commit 4e873f417ab4f3e68a102548167f2b0a005edfad.

See comments in #3945 about the failure.
2020-08-04 13:20:30 +03:00
Adrià Puigdomènech
4e873f417a
Avoid re-flattening in jit() when no donate_argnums are present. (#3945)
Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.
2020-08-03 15:13:28 -07:00
John Aslanides
8a8bb702d2
Catch invalid (negative) in/out axes in vmap. (#3926)
2020-08-01 13:33:11 -07:00
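A check of the kind named in the commit above might look like the following sketch. It assumes an axis specification is an int, `None`, or a nested tuple/list of those; the helper name `check_axes` and the error wording are hypothetical, not taken from the library.

```python
def check_axes(name, axes):
    """Raise ValueError if any axis in a (possibly nested) spec is negative."""
    if axes is None:
        return  # None means "do not map over this argument"
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError(f"{name} must be non-negative, got {axes}")
        return
    if isinstance(axes, (tuple, list)):
        for sub in axes:
            check_axes(name, sub)
        return
    raise TypeError(f"{name} must be an int, None, or a tuple/list, got {type(axes)}")


check_axes("in_axes", (0, None, (1, 0)))  # passes silently
try:
    check_axes("out_axes", -1)
except ValueError as err:
    print(err)  # out_axes must be non-negative, got -1
```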
Jean-Baptiste Lespiau
4853eb103c
Fix static_argnums in xla_computation. (#3924) 2020-07-31 15:15:51 -07:00
Matthew Johnson
843d710116
allow mask to return output logical shape (#3929)
When the `mask` argument `out_shape` is not provided, or when it has
value `None`, return the output logical shape to the user.
2020-07-31 15:11:01 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00
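The difference between staging by data dependence and staging by dynamic scope, as described in the omnistaging commit above, can be shown with a toy tracer. Everything here (`Var`, `Trace`, `new_trace`, the `add` primitive) is a hypothetical stand-in written for illustration and does not mirror JAX internals.

```python
from contextlib import contextmanager


class Var:
    """Placeholder for a value known only when the staged program runs."""

    def __init__(self, name):
        self.name = name


class Trace:
    def __init__(self, omnistaging):
        self.omnistaging = omnistaging
        self.equations = []  # primitive calls that were staged out


_current_trace = None  # innermost active trace, if any


@contextmanager
def new_trace(omnistaging):
    global _current_trace
    prev, _current_trace = _current_trace, Trace(omnistaging)
    try:
        yield _current_trace
    finally:
        _current_trace = prev


def add(x, y):
    """Toy primitive: either runs immediately or is recorded for later."""
    trace = _current_trace
    depends_on_input = isinstance(x, Var) or isinstance(y, Var)
    if trace is not None and (trace.omnistaging or depends_on_input):
        out = Var(f"v{len(trace.equations)}")
        trace.equations.append(("add", x, y, out))
        return out
    return x + y  # evaluated eagerly, at trace time


def f(x):
    c = add(1, 2)     # no data dependence on x
    return add(x, c)


def count_staged(omnistaging):
    with new_trace(omnistaging) as trace:
        f(Var("x"))
    return len(trace.equations)


print(count_staged(False))  # 1: only add(x, c) is staged; c is computed eagerly
print(count_staged(True))   # 2: both calls in the dynamic scope are staged
```

In the toy, the eagerly computed `c` plays the role of a constant instantiated at trace time, which is the memory cost the commit message says omnistaging avoids.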
Matthew Johnson
616d63b19a
fix vmap error, fixes #3877 (#3879) 2020-07-27 21:51:12 -07:00
Matthew Johnson
67ad5eb5a1
add result_shape option to xla_computation (#3844)
2020-07-23 19:38:56 -07:00
Roy Frostig
fa2a0275c8 revert #3674 2020-07-17 15:44:51 -07:00
Skye Wanderman-Milne
2df486ee83
Note in pmap docs that pmap compiles like jit. (#3787) 2020-07-17 15:11:26 -07:00
Roy Frostig
6416ca0e9d append filtered stack traces to error messages raised under transformations 2020-07-16 17:12:09 -07:00
Skye Wanderman-Milne
44fbce56f9
Revert "Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)" (#3780)
This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d.

This is breaking some google-internal users of xla_computation. Reverting while I investigate.
2020-07-16 15:22:40 -07:00
Skye Wanderman-Milne
dbc3f83f6d
Add in_parts and out_parts optional arguments jax.xla_computation. (#3771)
This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`.
2020-07-15 14:56:58 -07:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Neil Girdhar
503e5973ce Make vjp cotangent functions pytree-like
Fixes #3667
2020-07-10 23:22:38 -04:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
Matthew Johnson
107689e91f
improve vmap axis spec structure mismatch errors (#3619)
* improve vmap axis spec structure mismatch errors

fixes #3613

* deflake
2020-06-30 22:19:16 -07:00
Tom Hennigan
2ef39abf54
Retain original docstring when vmap'ing functions. (#3592) 2020-06-29 21:16:02 -07:00
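As a rough illustration of keeping a function's metadata through a wrapping transformation, the sketch below uses `functools.wraps` from the standard library. The decorator `toy_transform` is hypothetical and does nothing but forward the call; how `vmap` itself builds its docstring is not shown here.

```python
import functools


def toy_transform(fun):
    """Hypothetical transformation that wraps `fun` and keeps its metadata."""

    @functools.wraps(fun)  # copies __name__, __doc__, __module__, and so on
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)

    return wrapped


@toy_transform
def square(x):
    """Return x squared."""
    return x * x


print(square.__name__)  # square
print(square.__doc__)   # Return x squared.
```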
Matthew Johnson
696958d2bd
add remat docstring (#3542)
* add remat docstring

first part of addressing #3314
2020-06-24 16:11:26 -07:00
Jony Hudson
677baa54dd
Clarify docstrings regarding usage of static arguments in jit and vmap. (#3484)
The docstring for pmap does not currently mention that any "non-data"
arguments need to be indicated in `static_broadcasted_argnums`. This
commit updates the docs to parallel those for `jax.jit` which does
explain this. Additionally, a remark is added to the `static_*` argument
descriptions on both jit and pmap, so that this point can be understood
without reading the whole docstring.
2020-06-24 15:47:09 -04:00
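Why "non-data" arguments need to be declared static can be seen in a toy model of a tracing compiler: static values become part of the compilation cache key, so they must be hashable and each new value triggers a fresh compilation. The function `toy_jit` below is a hypothetical stand-in and not `jax.jit` or `jax.pmap`.

```python
def toy_jit(fun, static_argnums=()):
    """Toy stand-in for a tracing compiler that counts compilations."""
    cache = {}

    def wrapped(*args):
        # Static arguments form the cache key, so they must be hashable.
        key = tuple((i, args[i]) for i in static_argnums)
        if key not in cache:
            wrapped.num_compiles += 1
            cache[key] = fun  # a real compiler would specialize on `key` here
        return cache[key](*args)

    wrapped.num_compiles = 0
    return wrapped


def scale(x, mode):
    # `mode` is configuration rather than numerical data.
    return x * 2 if mode == "double" else x


fast_scale = toy_jit(scale, static_argnums=(1,))
fast_scale(1.0, "double")
fast_scale(2.0, "double")  # same static value: cache hit
fast_scale(3.0, "same")    # new static value: compiled again
print(fast_scale.num_compiles)  # 2
```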
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Peter Hawkins
86fcfbfa1a
Fix memory leak when no axis is provided to pmap. (#3394)
* Fix memory leak when no axis is provided to pmap.

* Work around flake8 false positive.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-23 09:29:58 -04:00
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Skye Wanderman-Milne
8f4ba7e679
Allow specifying both devices and axis_size to pmap. (#3475)
This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.
2020-06-19 15:51:12 -07:00
Matthew Johnson
fe14aa3e00
Merge branch 'master' into changelist/316251368 2020-06-15 22:15:34 -07:00
Matthew Johnson
d4c6cb62ab print warning when doing jit-of-pmap 2020-06-15 21:37:30 -07:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do is decorate a
function with `@invertible`, which will make AD apply a more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of an "invertible" function is one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
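The memory-efficient VJP idea from the commit above, reconstructing primal values in lock-step with the gradient computation, can be sketched in plain Python with an additive coupling layer of the kind used in reversible networks. The functions `F`, `G`, `layer_forward`, `layer_backward`, `network`, and `network_vjp` are illustrative choices made for this sketch; this is not the `@invertible` implementation.

```python
import math


def F(u):
    return math.tanh(u)


def dF(u):
    return 1.0 - math.tanh(u) ** 2


def G(u):
    return math.sin(u)


def dG(u):
    return math.cos(u)


def layer_forward(x1, x2):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2


def layer_backward(y1, y2, g1, g2):
    """Recover the layer inputs from its outputs and pull cotangents back."""
    # Invert the layer, so no activations have to be stored in the forward pass.
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    # Transpose the two forward equations in reverse order.
    g_y1 = g1 + g2 * dG(y1)    # y1 also feeds y2 through G
    g_x2 = g2 + g_y1 * dF(x2)  # x2 feeds y2 directly and y1 through F
    g_x1 = g_y1
    return (x1, x2), (g_x1, g_x2)


def network(x1, x2, depth=3):
    for _ in range(depth):
        x1, x2 = layer_forward(x1, x2)
    return x1, x2


def network_vjp(y1, y2, g1, g2, depth=3):
    """Walk the layers backwards using only the final outputs."""
    for _ in range(depth):
        (y1, y2), (g1, g2) = layer_backward(y1, y2, g1, g2)
    return (y1, y2), (g1, g2)


outputs = network(0.3, -0.7)
inputs, grads = network_vjp(*outputs, 1.0, 1.0)  # gradient of y1 + y2
print(inputs)  # approximately the original (0.3, -0.7)
print(grads)
```

The backward pass needs only the final outputs, so memory use does not grow with `depth`; the price is recomputing `F` and `G` once more per layer.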
David Budden
3c78605bb8 Propagate raw __name__ and __doc__ of functions wrapped by jit and sharded_jit. 2020-06-15 10:32:09 +01:00
Peter Hawkins
04c9b32788
Small edit to documentation. (#3406) 2020-06-11 17:11:32 -04:00
Matthew Johnson
ee428008c4 yet another doc fix 2020-06-08 15:06:00 -07:00
George Necula
65d95f10ea
A couple of ad_util.zero were missed in #3222 (#3363) 2020-06-08 19:59:25 +03:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Peter Hawkins
96d72fca56
Edit documentation. (#3358)
Use :func:, :class: and :meth: when referring to Python objects.
Use :ref: for hyperlinks.
Fix some bad formatting.
2020-06-08 10:37:50 -04:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
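A symbolic zero that carries a description of the value it stands for can be sketched as below. Tangents are modeled as flat Python lists and the "aval" is reduced to a shape and a dtype; the names `Zero`, `add_tangents`, and `instantiate_zeros` are used loosely for illustration and are not the definitions in `ad_util`.

```python
import math


class Zero:
    """Symbolic zero tangent that remembers the shape and dtype it stands for."""

    def __init__(self, shape, dtype=float):
        self.shape = shape
        self.dtype = dtype


def add_tangents(x, y):
    """Add two tangents (flat lists here), skipping work for symbolic zeros."""
    if isinstance(x, Zero):
        return y
    if isinstance(y, Zero):
        return x
    return [a + b for a, b in zip(x, y)]


def instantiate_zeros(t):
    """Materialize a symbolic zero; possible only because it knows its shape."""
    if isinstance(t, Zero):
        return [t.dtype(0)] * math.prod(t.shape)
    return t


print(add_tangents(Zero((2,)), [1.0, 2.0]))  # [1.0, 2.0]
print(instantiate_zeros(Zero((2, 3))))       # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```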
Adam Paszke
74d160f5e0
Don't keep primal arguments and results in the linearized jaxpr (#3233)
Linearized functions are supposed to take tangent types to tangent
types, and so all primal arguments are unused and primal results get
replaced by units.
2020-06-05 17:22:55 +02:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00