112 Commits

Author SHA1 Message Date
Matthew Johnson
60de46a140
Merge pull request #2591 from google/tracer-printing
make tracers tree-pretty-print their contents
2020-04-03 15:47:41 -07:00
Matthew Johnson
297c90246d make tracers tree-pretty-print their contents 2020-04-02 21:04:12 -07:00
Matthew Johnson
5d3f1bdf4c tell mypy: using __init__ to reinitialize is OK 2020-04-02 20:14:12 -07:00
Matthew Johnson
6d4987cc04 make core.trace_state resetting be thread-local 2020-04-02 18:19:44 -07:00
Matthew Johnson
b78b7a0309 add global trace state checks to more tests 2020-04-02 18:03:58 -07:00
Matthew Johnson
e017a923a2 fix typo 2020-03-30 22:06:00 -07:00
Matthew Johnson
70a3f47bed comments/defaults for process_custom_{jv,vj}p_call 2020-03-30 12:02:25 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
f99720b70a add type annotations to core.py tracing machinery
also add .copy() method to core.trace_state global trace state
2020-03-28 14:58:35 -07:00
Matthew Johnson
74c20509eb improve custom_jvp error messages, fixes #2502 2020-03-24 21:45:50 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Matthew Johnson
1d0b7e2b5c make jaxpr pretty-print show multiple outputs 2020-03-19 11:26:29 -07:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Peter Hawkins
985d5f7327
Fix Python 3.5 support. (#2439)
* Fix Python 3.5 compatibility problems.
2020-03-17 17:01:04 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
George Necula
282225f676
Added some pytype annotations (#2386)
Tried to catch all uses of linear_util.WrappedFun
2020-03-09 20:41:01 +01:00
Chris Jones
1e7d13b5f9
Give Vars an aval. (#2299) 2020-03-09 10:14:23 +01:00
George Necula
88677b1f67
Merge pull request #2233 from gnecula/bug_fix3
Expanded the error messages due to re-using tracers saved in global s…
2020-02-17 15:52:52 +01:00
Sharad Vikram
b92656db8b
Set call_p.multiple_results to True. 2020-02-14 23:29:33 -08:00
George Necula
deb21ef15d Expanded the error messages due to re-using tracers saved in global state.
Previously these errors were raising Exception (as other internal errors),
but these errors may arise out of mis-use of tracers.
2020-02-15 06:35:49 +01:00
George Necula
938336e08a
Merge pull request #2216 from gnecula/documentation
Added the first draft of the Jaxpr documentation.
2020-02-14 07:23:47 +01:00
Sharad Vikram
e93697461b
Make core.call_p a call primitive. (#2223) 2020-02-13 13:55:19 -08:00
George Necula
20dbc62277 Updated docstrings based on review comments 2020-02-13 09:28:01 +01:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
George Necula
20f9230f6e Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
2020-02-11 10:06:08 +01:00
George Necula
ae3003e9d4 Simplify bound_subjaxprs.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.

We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)

This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
2020-02-06 09:34:53 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
1c134f8a6d
Rename Tracer.trace to Tracer._trace. (#2114)
Makes the .trace() method work on arrays.
2020-01-29 16:23:27 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Matthew Johnson
07260f6572
remove hasing methods from core.Literal (#2038) 2020-01-22 17:19:14 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Roy Frostig
afb8af19ff implement JVP of while loop
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-10 15:31:51 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Tom Hennigan
e2f7f37f59 Allow Var to be a key in a JAX tree. 2020-01-06 22:59:55 -08:00
George Necula
ea9e93282c Fix bug with caching in presence of JVP and JIT
The bug was that the auxiliary output of the process_env_traces
was mutated before the next cache hit, so the cache content changed.

Fixes: #1945
2020-01-05 16:37:27 +01:00
George van den Driessche
a73106b37c Avoid stack overflow when JITting a function that uses copy.copy or copy.deepcopy. (#1834) 2019-12-10 21:48:51 -05:00
Matthew Johnson
7083b0a78e roll back previous commit #1829
There was a mysterious failure on an internal test, and that
mysteriousness means I didn't fully understand the attempted fix, so
best to roll back for now.
2019-12-06 22:28:57 -08:00
Matthew Johnson
80f455d3f0 make eval_jaxpr get jit cache hits 2019-12-06 21:10:52 -08:00
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
George Necula
0cb3b433b5 Change in how we print sorted params for eqns 2019-11-28 07:34:40 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
2bb74b627e Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output 2019-11-26 14:05:08 +01:00
Matthew Johnson
1fcebbaa0e fix reference cycle in jaxpr tracing using weakrefs
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
,particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224,
since when primitives only had single outputs we could identify each
primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).

Those cycles caused memory leaks. This commit removes the strong
reference cycle using weakrefs. In particular, equation recipes only
hold weak references to their output tracers.

Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.

Bug find and test code from @trevorcai. Thanks!
2019-11-19 15:23:08 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
James Bradbury
bc0e79767b fix XLA metadata for primitives with many args 2019-10-08 10:57:36 -07:00
James Bradbury
4e6385f8a6 remove source lines (caching makes them ~useless) 2019-10-04 11:59:06 -07:00