146 Commits

Author SHA1 Message Date
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
Roy Frostig
c293a102b2 work around mypy 2020-05-21 20:54:02 -07:00
Roy Frostig
69d7bcf7fb except-and-raise during jaxpr checking, adding jaxpr as context, and simplify type environment 2020-05-21 20:02:30 -07:00
Roy Frostig
8e61ce8d1a fix unitvar comparisons and move to class attributes 2020-05-21 18:28:09 -07:00
Roy Frostig
7ff389bd03 extend type transfer to all primitives, including call and map primitives 2020-05-21 13:21:07 -07:00
Roy Frostig
e2cc568997 raise type errors consistently in jaxpr checker 2020-05-21 13:21:07 -07:00
Roy Frostig
1e55603344 avoid attempt to read literals from the typechecking environment 2020-05-21 13:21:07 -07:00
Roy Frostig
0f109d9fe0 add jaxpr context to typechecker error message 2020-05-21 13:21:07 -07:00
Roy Frostig
3705252be6 have UnitVar subclass Var (caught by mypy) 2020-05-21 13:21:07 -07:00
Roy Frostig
42e7e20eab update check_jaxpr doc 2020-05-21 13:21:07 -07:00
Roy Frostig
cc34ed2693 check aval compatibility, not strict equality, when typechecking jaxpr equations 2020-05-21 13:21:07 -07:00
Roy Frostig
0c2c558482 check that variables are typed equally throughout a jaxpr 2020-05-21 13:21:07 -07:00
Roy Frostig
8e70769cba factor out jaxpr-check context and variable environment 2020-05-21 13:21:07 -07:00
Roy Frostig
1205f7a00f factor out jaxpr equation checks 2020-05-21 13:21:07 -07:00
Roy Frostig
94b1f631ea raise TypeError for jaxpr typechecking errors 2020-05-21 13:21:07 -07:00
Roy Frostig
82a9af519a typecheck jaxpr equations 2020-05-21 13:21:07 -07:00
Matthew Johnson
a4094f72a4
revise "Tracer with raw numpy" error message (#3160)
* revise "Tracer with raw numpy" error message

fixes #3133

* fix f-string typo

* fix typo

Co-authored-by: James Bradbury <jekbradbury@google.com>

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-20 19:09:44 -07:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
9f0795b8f1 Unified the eager and jit paths
Added error checking for outfeed_receiver not started to primitive computations
2020-05-07 16:24:13 +03:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
804e083e66
Fix pytype for copybara import (#2995) 2020-05-07 13:28:24 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previosuly if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
Peter Hawkins
50dc44be6f
Fix IntEnum test when checking is enabled. (#2981) 2020-05-07 08:46:13 +03:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Jacob Kelly
cc0e9a3189
refactor ode tests, add scipy benchmark (#2824)
* refactor ode tests, add scipy benchmark

remove double import

rename to scipy merge vmap test properly

* clean up more global trace state after errors

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 21:53:38 -07:00
Matthew Johnson
89e3840e63
handle mapped_invars correctly in more places (#2828)
fixes #2822

We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
  5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.

The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).

This commit fixes those issues by
  1. making `mapped_invars` non-optional,
  2. handling `mapped_invars` correctly in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
  3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.

This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
Matthew Johnson
8f902452ab
only maximally stage out for some call primitives (#2834)
fixes #2833
2020-04-24 18:19:24 -07:00
George Necula
a2c06d6113
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split

Now we print:

ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>

* Fixed tests, slight change to the error message

* Expanded the FAQ entry about abstract tracers for higher-order primitives

* Added clarification for tracers inside jit of grad

* Updated FAQ language in response to reviews
2020-04-22 09:25:06 +02:00
Matthew Johnson
b1cb3a5dea
factor out process_map / post_process_map (#2788)
* factor out process_map / post_process_map

Also fix a bug from reusing post_process_call for pmap. Fixes #2787

* consolidate call_bind / map_bind code
2020-04-21 18:12:02 -07:00
Matthew Johnson
c9c14c02d4
Merge pull request #2602 from google/float-and-complex-builtins-error
make float and complex builtins error on Tracers
2020-04-09 07:31:23 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Matthew Johnson
7ab67756c8 make float and complex builtins error on Tracers
cf. #2508
2020-04-09 00:09:51 -07:00
Matthew Johnson
60de46a140
Merge pull request #2591 from google/tracer-printing
make tracers tree-pretty-print their contents
2020-04-03 15:47:41 -07:00
Matthew Johnson
297c90246d make tracers tree-pretty-print their contents 2020-04-02 21:04:12 -07:00
Matthew Johnson
5d3f1bdf4c tell mypy: using __init__ to reinitialize is OK 2020-04-02 20:14:12 -07:00
Matthew Johnson
6d4987cc04 make core.trace_state resetting be thread-local 2020-04-02 18:19:44 -07:00
Matthew Johnson
b78b7a0309 add global trace state checks to more tests 2020-04-02 18:03:58 -07:00
Matthew Johnson
e017a923a2 fix typo 2020-03-30 22:06:00 -07:00
Matthew Johnson
70a3f47bed comments/defaults for process_custom_{jv,vj}p_call 2020-03-30 12:02:25 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
f99720b70a add type annotations to core.py tracing machinery
also add .copy() method to core.trace_state global trace state
2020-03-28 14:58:35 -07:00
Matthew Johnson
74c20509eb improve custom_jvp error messages, fixes #2502 2020-03-24 21:45:50 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Matthew Johnson
1d0b7e2b5c make jaxpr pretty-print show multiple outputs 2020-03-19 11:26:29 -07:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Peter Hawkins
985d5f7327
Fix Python 3.5 support. (#2439)
* Fix Python 3.5 compatibility problems.
2020-03-17 17:01:04 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00