153 Commits

Author SHA1 Message Date
Roy Frostig
dc4c9f0450 change cond primitive to an indexed conditional with multiple branch functions
in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
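
For readers unfamiliar with the term, an "indexed conditional" selects one of several branch functions by integer index instead of choosing between two branches by a boolean. The following is a minimal pure-Python sketch of that idea, not the JAX primitive itself. The names `indexed_cond` and `bool_cond`, the out-of-range error, and the False -> 0 / True -> 1 encoding are assumptions made for illustration.

```python
from typing import Callable, Sequence


def indexed_cond(index: int, branches: Sequence[Callable], *operands):
    """Hypothetical stand-in for an indexed conditional: run branches[index]."""
    if not 0 <= index < len(branches):
        raise IndexError(
            f"branch index {index} out of range for {len(branches)} branches")
    return branches[index](*operands)


def bool_cond(pred: bool, true_fun: Callable, false_fun: Callable, *operands):
    # Assumed encoding for this sketch: False -> branch 0, True -> branch 1.
    return indexed_cond(int(pred), (false_fun, true_fun), *operands)


double = lambda x: 2 * x
negate = lambda x: -x
square = lambda x: x * x

print(indexed_cond(2, (double, negate, square), 3))  # selects square
print(bool_cond(True, double, negate, 3))            # selects double
```

Under this encoding the two-branch boolean conditional is just the special case with exactly two branches.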
2020-06-03 22:19:15 -07:00
Matthew Johnson
177e7cf311
moved check_jaxpr code around to match eval_jaxpr (#3240)
* moved check_jaxpr code around to match eval_jaxpr

This change is mostly stylistic; it brings check_jaxpr closer to
eval_jaxpr (and the other jaxpr interpreters) in organization. There's a
slight tweak to an error message which lets us save some slightly
redundant code.

* fixes and tweaks
2020-06-02 19:10:55 -07:00
Peter Hawkins
042df4ebff
Fix pytype errors. (#3291) 2020-06-02 10:26:43 -04:00
Peter Hawkins
34065df248
Add some type annotations to core and partial_eval. (#3251) 2020-06-01 21:45:36 -04:00
Matthew Johnson
49a441f745
revisions to #3197 (#3264)
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
2020-06-01 13:24:40 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device memory are constrained by the
     additional copy of their parameters and other state, while they typically
     require only one copy. Removing the extra copy typically frees 100M+ of
     device memory and is a critical optimization for larger models to match
     state-of-the-art performance in other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and will instruct XLA that it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough that JAX does not just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety, since it is still possible to
reuse a key as part of a traced computation; however, it can be used to support
this feature (somewhat inefficiently) in eager mode.
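
The donation semantics described in this message can be modeled without any device at all. The sketch below is a toy in pure Python, not how JAX or XLA implement aliasing. `ToyBuffer` and `toy_jit` are hypothetical names; only the `donate_argnums` parameter name is taken from the message. It shows the one observable rule: a donated argument is consumed by the call and cannot be read afterwards.

```python
class ToyBuffer:
    """Hypothetical device buffer that becomes invalid once donated."""

    def __init__(self, value):
        self._value = value
        self._valid = True

    def read(self):
        if not self._valid:
            raise RuntimeError("buffer was donated and is no longer valid")
        return self._value

    def donate(self):
        # Hand the contents to the callee and invalidate this handle.
        value = self.read()
        self._valid = False
        return value


def toy_jit(fun, donate_argnums=()):
    """Toy wrapper: positional args listed in donate_argnums are consumed."""
    if isinstance(donate_argnums, int):
        donate_argnums = (donate_argnums,)

    def wrapped(*args):
        values = [a.donate() if i in donate_argnums else a.read()
                  for i, a in enumerate(args)]
        return ToyBuffer(fun(*values))

    return wrapped


f = toy_jit(lambda x: x ** 2, donate_argnums=0)
x = ToyBuffer(3)
y = f(x)
print(y.read())
try:
    x.read()
except RuntimeError as err:
    print("error:", err)
```

Rebinding the result to the same name, as in `x = f(x)`, never touches the invalidated handle, which is why the loop in the message remains valid.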
2020-05-31 15:00:16 -07:00
Roy Frostig
e80e9634a7 jaxpr-dependent gensym to avoid var duplication 2020-05-27 12:03:34 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is easy to misuse:
if the first argument is not a JAX type, it silently disappears.
This means that `lax.tie_in((x, x), const)` is the same as `const`
even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
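
The failure mode and the fix can be illustrated with a small pure-Python sketch. This is not the real `bind` or `lax.tie_in`. `ToyTracer`, `VALID_TYPES`, `check_args`, and `toy_tie_in` are hypothetical names, and the set of accepted types is an assumption chosen for the example.

```python
class ToyTracer:
    """Hypothetical stand-in for a traced JAX value."""

    def __init__(self, name):
        self.name = name


# Assumed set of acceptable argument types for this sketch only.
VALID_TYPES = (int, float, complex, bool, ToyTracer)


def check_args(args):
    """Unconditionally reject arguments that are not valid toy JAX types."""
    for i, arg in enumerate(args):
        if not isinstance(arg, VALID_TYPES):
            raise TypeError(
                f"argument {i} of type {type(arg).__name__} "
                "is not a valid toy JAX type")


def toy_tie_in(x, y):
    # Without this check, a tuple passed as `x` would be silently ignored
    # and the data dependence on the tracers inside it would be lost.
    check_args((x, y))
    return y


x = ToyTracer("x")
print(toy_tie_in(x, 1.0))
try:
    toy_tie_in((x, x), 1.0)
except TypeError as err:
    print("rejected:", err)
```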
2020-05-24 19:12:37 +03:00
Roy Frostig
c293a102b2 work around mypy 2020-05-21 20:54:02 -07:00
Roy Frostig
69d7bcf7fb except-and-raise during jaxpr checking, adding jaxpr as context, and simplify type environment 2020-05-21 20:02:30 -07:00
Roy Frostig
8e61ce8d1a fix unitvar comparisons and move to class attributes 2020-05-21 18:28:09 -07:00
Roy Frostig
7ff389bd03 extend type transfer to all primitives, including call and map primitives 2020-05-21 13:21:07 -07:00
Roy Frostig
e2cc568997 raise type errors consistently in jaxpr checker 2020-05-21 13:21:07 -07:00
Roy Frostig
1e55603344 avoid attempt to read literals from the typechecking environment 2020-05-21 13:21:07 -07:00
Roy Frostig
0f109d9fe0 add jaxpr context to typechecker error message 2020-05-21 13:21:07 -07:00
Roy Frostig
3705252be6 have UnitVar subclass Var (caught by mypy) 2020-05-21 13:21:07 -07:00
Roy Frostig
42e7e20eab update check_jaxpr doc 2020-05-21 13:21:07 -07:00
Roy Frostig
cc34ed2693 check aval compatibility, not strict equality, when typechecking jaxpr equations 2020-05-21 13:21:07 -07:00
Roy Frostig
0c2c558482 check that variables are typed equally throughout a jaxpr 2020-05-21 13:21:07 -07:00
Roy Frostig
8e70769cba factor out jaxpr-check context and variable environment 2020-05-21 13:21:07 -07:00
Roy Frostig
1205f7a00f factor out jaxpr equation checks 2020-05-21 13:21:07 -07:00
Roy Frostig
94b1f631ea raise TypeError for jaxpr typechecking errors 2020-05-21 13:21:07 -07:00
Roy Frostig
82a9af519a typecheck jaxpr equations 2020-05-21 13:21:07 -07:00
Matthew Johnson
a4094f72a4
revise "Tracer with raw numpy" error message (#3160)
* revise "Tracer with raw numpy" error message

fixes #3133

* fix f-string typo

* fix typo

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-20 19:09:44 -07:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
9f0795b8f1 Unified the eager and jit paths
Added error checking for outfeed_receiver not started to primitive computations
2020-05-07 16:24:13 +03:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
804e083e66
Fix pytype for copybara import (#2995) 2020-05-07 13:28:24 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is easy to misuse:
if the first argument is not a JAX type, it silently disappears.
This means that `lax.tie_in((x, x), const)` is the same as `const`
even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
Peter Hawkins
50dc44be6f
Fix IntEnum test when checking is enabled. (#2981) 2020-05-07 08:46:13 +03:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Jacob Kelly
cc0e9a3189
refactor ode tests, add scipy benchmark (#2824)
* refactor ode tests, add scipy benchmark

remove double import

rename to scipy merge vmap test properly

* clean up more global trace state after errors

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 21:53:38 -07:00
Matthew Johnson
89e3840e63
handle mapped_invars correctly in more places (#2828)
fixes #2822

We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
  5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.

The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).

This commit fixes those issues by
  1. making `mapped_invars` non-optional,
  2. handling `mapped_invars` correctly in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
  3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.

This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in Python mechanisms. Moreover, `call_primitive=True` or `map_primitive=True` implies things about which `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
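
As an illustration of what "growing" `mapped_invars` means in the JVP case, here is a pure-Python sketch. It is not the code from this commit. The function name `grow_for_jvp` is hypothetical, and the rule that each new tangent input inherits the mapped flag of its primal is an assumption made for the example.

```python
def grow_for_jvp(mapped_invars, tangent_is_nonzero):
    """Extend mapped_invars with one flag per nonzero-tangent input.

    Assumption for this sketch: a tangent input is mapped exactly when
    its corresponding primal input is mapped.
    """
    assert len(mapped_invars) == len(tangent_is_nonzero)
    tangent_flags = tuple(
        m for m, nz in zip(mapped_invars, tangent_is_nonzero) if nz)
    return tuple(mapped_invars) + tangent_flags


# Three primal inputs; the first two have nonzero tangents.
print(grow_for_jvp((True, False, True), (True, True, False)))
```

With `mapped_invars` left as None (meaning all-true of any length), no such bookkeeping is needed, which matches the message's explanation of why the bug went unnoticed in the common case.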
2020-04-24 18:45:34 -07:00
Matthew Johnson
8f902452ab
only maximally stage out for some call primitives (#2834)
fixes #2833
2020-04-24 18:19:24 -07:00
George Necula
a2c06d6113
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split

Now we print:

ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>

* Fixed tests, slight change to the error message

* Expanded the FAQ entry about abstract tracers for higher-order primitives

* Added clarification for tracers inside jit of grad

* Updated FAQ language in response to reviews
2020-04-22 09:25:06 +02:00
Matthew Johnson
b1cb3a5dea
factor out process_map / post_process_map (#2788)
* factor out process_map / post_process_map

Also fix a bug from reusing post_process_call for pmap. Fixes #2787

* consolidate call_bind / map_bind code
2020-04-21 18:12:02 -07:00
Matthew Johnson
c9c14c02d4
Merge pull request #2602 from google/float-and-complex-builtins-error
make float and complex builtins error on Tracers
2020-04-09 07:31:23 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit)) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_test.py on Mac. They segfault,
apparently due to the same issue as #432.
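
The two constructors can be sketched as classmethods on a tuple subclass. This is a simplified pure-Python toy, not the class from this commit. The `unit` sentinel stands in for `core.unit`, and the `is_known` helper is a hypothetical addition for the example.

```python
from typing import Any, Optional, Tuple

unit = object()  # hypothetical stand-in for core.unit


class PartialVal(tuple):
    """Toy pair (aval, const).

    known:   aval is None and const holds the value.
    unknown: aval describes the value and const is `unit`.
    """

    def __new__(cls, xs: Tuple[Optional[Any], Any]):
        return tuple.__new__(cls, xs)

    @classmethod
    def known(cls, const):
        return cls((None, const))

    @classmethod
    def unknown(cls, aval):
        return cls((aval, unit))

    def is_known(self) -> bool:
        # Hypothetical helper, included only to make the sketch readable.
        return self[0] is None


pv_known = PartialVal.known(3.0)
pv_unknown = PartialVal.unknown("float32[]")  # a string stands in for an aval
print(pv_known.is_known(), pv_unknown.is_known())
```

The named constructors read as intent at the call site, whereas the raw tuple form requires remembering which slot holds `None` and which holds `unit`.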
2020-04-09 13:00:33 +03:00
Matthew Johnson
7ab67756c8 make float and complex builtins error on Tracers
cf. #2508
2020-04-09 00:09:51 -07:00
Matthew Johnson
60de46a140
Merge pull request #2591 from google/tracer-printing
make tracers tree-pretty-print their contents
2020-04-03 15:47:41 -07:00
Matthew Johnson
297c90246d make tracers tree-pretty-print their contents 2020-04-02 21:04:12 -07:00
Matthew Johnson
5d3f1bdf4c tell mypy: using __init__ to reinitialize is OK 2020-04-02 20:14:12 -07:00
Matthew Johnson
6d4987cc04 make core.trace_state resetting be thread-local 2020-04-02 18:19:44 -07:00
Matthew Johnson
b78b7a0309 add global trace state checks to more tests 2020-04-02 18:03:58 -07:00
Matthew Johnson
e017a923a2 fix typo 2020-03-30 22:06:00 -07:00
Matthew Johnson
70a3f47bed comments/defaults for process_custom_{jv,vj}p_call 2020-03-30 12:02:25 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
f99720b70a add type annotations to core.py tracing machinery
also add .copy() method to core.trace_state global trace state
2020-03-28 14:58:35 -07:00