171 Commits

Author SHA1 Message Date
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
George Necula
0cb3b433b5 Change in how we print sorted params for eqns 2019-11-28 07:34:40 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
2bb74b627e Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output 2019-11-26 14:05:08 +01:00
Matthew Johnson
1fcebbaa0e fix reference cycle in jaxpr tracing using weakrefs
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
,particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224,
since when primitives only had single outputs we could identify each
primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).

Those cycles caused memory leaks. This commit removes the strong
reference cycle using weakrefs. In particular, equation recipes only
hold weak references to their output tracers.

Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.

Bug find and test code from @trevorcai. Thanks!
2019-11-19 15:23:08 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
James Bradbury
bc0e79767b fix XLA metadata for primitives with many args 2019-10-08 10:57:36 -07:00
James Bradbury
4e6385f8a6 remove source lines (caching makes them ~useless) 2019-10-04 11:59:06 -07:00
James Bradbury
59343b1a23 provide lax primitive + src line as XLA debuginfo 2019-10-03 17:56:25 -07:00
Matthew Johnson
98c7567a0d add flag for logging when jit performs compilation 2019-08-23 08:17:41 -07:00
Matthew Johnson
f56312bbff remove pat_fmap 2019-08-21 13:53:57 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
3c37a3260a Update linearize to no-tuple version 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
20fad746a8 De-dup equations with multiple lhs vars when creating a jaxpr
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Jamie Townsend
9b19afd4d9 Implement Jaxpr __hash__
This means that primitives like scatter, which have a Jaxpr in their
**params, will get cache hits appropriately.
2019-08-21 14:27:23 +01:00
Brian Patton
d07107af5a
Makes the Tracer object weakref-able 2019-08-16 17:18:44 -05:00
Peter Hawkins
6dc730a5f4 Make JAX tracer state thread-local. Allows performing traces in separate threads.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.

(Revives https://github.com/google/jax/pull/734.)
2019-08-09 13:55:20 -04:00
Jamie Townsend
21a69884fd call_wrapped in core.call_impl 2019-07-22 17:09:03 +01:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
b53bccc5d0 make more literals nontrivially hashable 2019-06-18 21:51:51 -07:00
Matthew Johnson
221426fadc de-duplicate constants staged into jaxprs
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-06-18 08:09:37 -07:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Matthew Johnson
310103f578 try a tweak on Literal for more cache hits 2019-05-28 22:50:52 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
d27bc0a129 add literals to jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-13 08:48:13 -07:00
Matthew Johnson
6e9718a229 add pretty-printing to TypedJaxpr 2019-05-11 13:28:47 -07:00
Matthew Johnson
65202821df improve core.typed_jaxpr arg typechecks 2019-05-11 10:45:14 -07:00
Matthew Johnson
4fcd96f926 make tests pass with skip_checks = False 2019-05-10 22:07:54 -07:00
Matthew Johnson
29e67f0119 scan bug fixed, other cleanup 2019-05-10 15:52:12 -07:00
Matthew Johnson
5cfa18015c fix things we broke on the path to scan 2019-05-10 14:00:21 -07:00
Matthew Johnson
360e39756f must guarantee progress on lattice...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:38 -07:00
Matthew Johnson
085f06e4b6 add some PartialVal invariants
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:24 -07:00
Matthew Johnson
c08b9fee47 remove const_env from check_jaxpr, add scan trans 2019-05-08 17:41:36 -07:00
Matthew Johnson
15d783a836 Merge remote-tracking branch 'origin/master' into differentiable-scan 2019-05-08 13:42:44 -07:00
Matthew Johnson
444cda493a add underscores, rename scan_initial -> scan 2019-05-08 13:41:27 -07:00
Matthew Johnson
e736a0a9a1 cleanup: remove call_initial, add xla pat_fmap 2019-05-08 13:41:27 -07:00
Matthew Johnson
4c2ec3e442 ship it
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:25 -07:00
Matthew Johnson
0988f6d8d5 pattern unpacking at jaxpr top-level (pair w/ @dougalm)
next step is to handle that new complexity in our jaxpr munging...

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:21 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
1c9035efca start scan transpose, but "nonlinear pack"!!
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:17 -07:00
Matthew Johnson
6736823021 victory! patial eval of scan (+ linearize!)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:15 -07:00
Matthew Johnson
d03cdc6397 introduce typedjaxpr to carry around literals etc
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:10 -07:00
Matthew Johnson
19e0f8de45 fix tuple unpacking problems 2019-05-06 22:43:31 -07:00
Matthew Johnson
bf6c15b59a update pmap to flatten correctly (was a perf bug)
also temporarily avoid DeviceTuples in optimizer states
2019-05-06 12:09:54 -07:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
f95f1c8dda fix bugs, make tests pass with skip_checks = False 2019-05-03 12:01:12 -07:00
Matthew Johnson
8e96e2f6df revert incorrect change to core.valid_jaxtype 2019-05-03 08:24:24 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00