87 Commits

Author SHA1 Message Date
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
George Necula
20f9230f6e Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
2020-02-11 10:06:08 +01:00
George Necula
ae3003e9d4 Simplify bound_subjaxprs.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.

We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)

This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
2020-02-06 09:34:53 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
1c134f8a6d
Rename Tracer.trace to Tracer._trace. (#2114)
Makes the .trace() method work on arrays.
2020-01-29 16:23:27 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Matthew Johnson
07260f6572
remove hasing methods from core.Literal (#2038) 2020-01-22 17:19:14 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Roy Frostig
afb8af19ff implement JVP of while loop
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-10 15:31:51 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Tom Hennigan
e2f7f37f59 Allow Var to be a key in a JAX tree. 2020-01-06 22:59:55 -08:00
George Necula
ea9e93282c Fix bug with caching in presence of JVP and JIT
The bug was that the auxiliary output of the process_env_traces
was mutated before the next cache hit, so the cache content changed.

Fixes: #1945
2020-01-05 16:37:27 +01:00
George van den Driessche
a73106b37c Avoid stack overflow when JITting a function that uses copy.copy or copy.deepcopy. (#1834) 2019-12-10 21:48:51 -05:00
Matthew Johnson
7083b0a78e roll back previous commit #1829
There was a mysterious failure on an internal test, and that
mysteriousness means I didn't fully understand the attempted fix, so
best to roll back for now.
2019-12-06 22:28:57 -08:00
Matthew Johnson
80f455d3f0 make eval_jaxpr get jit cache hits 2019-12-06 21:10:52 -08:00
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
George Necula
0cb3b433b5 Change in how we print sorted params for eqns 2019-11-28 07:34:40 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
2bb74b627e Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output 2019-11-26 14:05:08 +01:00
Matthew Johnson
1fcebbaa0e fix reference cycle in jaxpr tracing using weakrefs
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
,particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224,
since when primitives only had single outputs we could identify each
primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).

Those cycles caused memory leaks. This commit removes the strong
reference cycle using weakrefs. In particular, equation recipes only
hold weak references to their output tracers.

Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.

Bug find and test code from @trevorcai. Thanks!
2019-11-19 15:23:08 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
James Bradbury
bc0e79767b fix XLA metadata for primitives with many args 2019-10-08 10:57:36 -07:00
James Bradbury
4e6385f8a6 remove source lines (caching makes them ~useless) 2019-10-04 11:59:06 -07:00
James Bradbury
59343b1a23 provide lax primitive + src line as XLA debuginfo 2019-10-03 17:56:25 -07:00
Matthew Johnson
98c7567a0d add flag for logging when jit performs compilation 2019-08-23 08:17:41 -07:00
Matthew Johnson
f56312bbff remove pat_fmap 2019-08-21 13:53:57 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
3c37a3260a Update linearize to no-tuple version 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
20fad746a8 De-dup equations with multiple lhs vars when creating a jaxpr
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Jamie Townsend
9b19afd4d9 Implement Jaxpr __hash__
This means that primitives like scatter, which have a Jaxpr in their
**params, will get cache hits appropriately.
2019-08-21 14:27:23 +01:00
Brian Patton
d07107af5a
Makes the Tracer object weakref-able 2019-08-16 17:18:44 -05:00
Peter Hawkins
6dc730a5f4 Make JAX tracer state thread-local. Allows performing traces in separate threads.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.

(Revives https://github.com/google/jax/pull/734.)
2019-08-09 13:55:20 -04:00
Jamie Townsend
21a69884fd call_wrapped in core.call_impl 2019-07-22 17:09:03 +01:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
b53bccc5d0 make more literals nontrivially hashable 2019-06-18 21:51:51 -07:00
Matthew Johnson
221426fadc de-duplicate constants staged into jaxprs
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-06-18 08:09:37 -07:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Matthew Johnson
310103f578 try a tweak on Literal for more cache hits 2019-05-28 22:50:52 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
d27bc0a129 add literals to jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-13 08:48:13 -07:00
Matthew Johnson
6e9718a229 add pretty-printing to TypedJaxpr 2019-05-11 13:28:47 -07:00
Matthew Johnson
65202821df improve core.typed_jaxpr arg typechecks 2019-05-11 10:45:14 -07:00
Matthew Johnson
4fcd96f926 make tests pass with skip_checks = False 2019-05-10 22:07:54 -07:00
Matthew Johnson
29e67f0119 scan bug fixed, other cleanup 2019-05-10 15:52:12 -07:00
Matthew Johnson
5cfa18015c fix things we broke on the path to scan 2019-05-10 14:00:21 -07:00
Matthew Johnson
360e39756f must guarantee progress on lattice...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:38 -07:00
Matthew Johnson
085f06e4b6 add some PartialVal invariants
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:24 -07:00