128 Commits

Author SHA1 Message Date
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
9f0795b8f1 Unified the eager and jit paths
Added error checking for outfeed_receiver not started to primitive computations
2020-05-07 16:24:13 +03:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
804e083e66
Fix pytype for copybara import (#2995) 2020-05-07 13:28:24 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previosuly if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
Peter Hawkins
50dc44be6f
Fix IntEnum test when checking is enabled. (#2981) 2020-05-07 08:46:13 +03:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Jacob Kelly
cc0e9a3189
refactor ode tests, add scipy benchmark (#2824)
* refactor ode tests, add scipy benchmark

remove double import

rename to scipy merge vmap test properly

* clean up more global trace state after errors

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 21:53:38 -07:00
Matthew Johnson
89e3840e63
handle mapped_invars correctly in more places (#2828)
fixes #2822

We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
  5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.

The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).

This commit fixes those issues by
  1. making `mapped_invars` non-optional,
  2. handling `mapped_invars` correctly in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
  3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.

This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
Matthew Johnson
8f902452ab
only maximally stage out for some call primitives (#2834)
fixes #2833
2020-04-24 18:19:24 -07:00
George Necula
a2c06d6113
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split

Now we print:

ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>

* Fixed tests, slight change to the error message

* Expanded the FAQ entry about abstract tracers for higher-order primitives

* Added clarification for tracers inside jit of grad

* Updated FAQ language in response to reviews
2020-04-22 09:25:06 +02:00
Matthew Johnson
b1cb3a5dea
factor out process_map / post_process_map (#2788)
* factor out process_map / post_process_map

Also fix a bug from reusing post_process_call for pmap. Fixes #2787

* consolidate call_bind / map_bind code
2020-04-21 18:12:02 -07:00
Matthew Johnson
c9c14c02d4
Merge pull request #2602 from google/float-and-complex-builtins-error
make float and complex builtins error on Tracers
2020-04-09 07:31:23 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Matthew Johnson
7ab67756c8 make float and complex builtins error on Tracers
cf. #2508
2020-04-09 00:09:51 -07:00
Matthew Johnson
60de46a140
Merge pull request #2591 from google/tracer-printing
make tracers tree-pretty-print their contents
2020-04-03 15:47:41 -07:00
Matthew Johnson
297c90246d make tracers tree-pretty-print their contents 2020-04-02 21:04:12 -07:00
Matthew Johnson
5d3f1bdf4c tell mypy: using __init__ to reinitialize is OK 2020-04-02 20:14:12 -07:00
Matthew Johnson
6d4987cc04 make core.trace_state resetting be thread-local 2020-04-02 18:19:44 -07:00
Matthew Johnson
b78b7a0309 add global trace state checks to more tests 2020-04-02 18:03:58 -07:00
Matthew Johnson
e017a923a2 fix typo 2020-03-30 22:06:00 -07:00
Matthew Johnson
70a3f47bed comments/defaults for process_custom_{jv,vj}p_call 2020-03-30 12:02:25 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
f99720b70a add type annotations to core.py tracing machinery
also add .copy() method to core.trace_state global trace state
2020-03-28 14:58:35 -07:00
Matthew Johnson
74c20509eb improve custom_jvp error messages, fixes #2502 2020-03-24 21:45:50 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Matthew Johnson
1d0b7e2b5c make jaxpr pretty-print show multiple outputs 2020-03-19 11:26:29 -07:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Peter Hawkins
985d5f7327
Fix Python 3.5 support. (#2439)
* Fix Python 3.5 compatibility problems.
2020-03-17 17:01:04 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
George Necula
282225f676
Added some pytype annotations (#2386)
Tried to catch all uses of linear_util.WrappedFun
2020-03-09 20:41:01 +01:00
Chris Jones
1e7d13b5f9
Give Vars an aval. (#2299) 2020-03-09 10:14:23 +01:00
George Necula
88677b1f67
Merge pull request #2233 from gnecula/bug_fix3
Expanded the error messages due to re-using tracers saved in global s…
2020-02-17 15:52:52 +01:00
Sharad Vikram
b92656db8b
Set call_p.multiple_results to True. 2020-02-14 23:29:33 -08:00
George Necula
deb21ef15d Expanded the error messages due to re-using tracers saved in global state.
Previously these errors were raising Exception (as other internal errors),
but these errors may arise out of mis-use of tracers.
2020-02-15 06:35:49 +01:00
George Necula
938336e08a
Merge pull request #2216 from gnecula/documentation
Added the first draft of the Jaxpr documentation.
2020-02-14 07:23:47 +01:00
Sharad Vikram
e93697461b
Make core.call_p a call primitive. (#2223) 2020-02-13 13:55:19 -08:00
George Necula
20dbc62277 Updated docstrings based on review comments 2020-02-13 09:28:01 +01:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
George Necula
20f9230f6e Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
2020-02-11 10:06:08 +01:00
George Necula
ae3003e9d4 Simplify bound_subjaxprs.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.

We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)

This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
2020-02-06 09:34:53 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
1c134f8a6d
Rename Tracer.trace to Tracer._trace. (#2114)
Makes the .trace() method work on arrays.
2020-01-29 16:23:27 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Matthew Johnson
07260f6572
remove hasing methods from core.Literal (#2038) 2020-01-22 17:19:14 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00