350 Commits

Author SHA1 Message Date
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
d8b75e1913 Reimplemented the passing of tokens with a Jaxpr transform 2020-05-07 16:24:13 +03:00
George Necula
304009d772 Added error checking when starting compiled computations without starting
the outfeed receiver.
2020-05-07 16:24:13 +03:00
George Necula
931cb3f684 Ensure that we carry state only for control-flow conditionals that use print 2020-05-07 16:24:13 +03:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previosuly if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
Tom Hennigan
4c2c5ad5f4
Add a note about jax.pmap when leading dim is smaller than num devices. (#2949) 2020-05-04 15:46:12 -07:00
George Necula
d315564ebf
Fixed a few more places where device commitment was lost. (#2913)
* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment
  if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed the LaztTest.test_zeros_ones_compilation
  stated to fail because the `sticky` parameter to lazy_force_computation
  was changing. Fixed this by removing stickyness from the compilation key.
* Expanded docstring for jax.device_put; expanded the
  device placement FAQ entry.
2020-05-04 11:30:28 +03:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
2020-05-01 14:37:13 -07:00
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
Matthew Johnson
790d92965c
iterate on jax.hessian docs (#2873)
* iterate on jax.hessian docs

* tweaks

* add back note about block structure
2020-04-28 16:41:26 -07:00
Peter Hawkins
ae6a3fe9c1
Document how jax.hessian and pytrees interact. (#2705)
* Document how jax.hessian and pytrees interact.
2020-04-28 14:07:35 -04:00
Peter Hawkins
0dbbc27bb1
Clarify that grad requires arguments to be differentiated to be of inexact type. (#2712) 2020-04-28 11:58:51 -04:00
Adam Paszke
5fe6b069df
Correct the order of .format arguments in vjp wrapper (#2866) 2020-04-28 09:07:08 -04:00
Matthew Johnson
89e3840e63
handle mapped_invars correctly in more places (#2828)
fixes #2822

We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
  5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.

The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).

This commit fixes those issues by
  1. making `mapped_invars` non-optional,
  2. handling `mapped_invars` correctly in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
  3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.

This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
Matthew Johnson
251834367f
use static_argnums in xla_computation (#2812)
* use static_argnums in xla_computation

fixes #1017

* add static_argnums to make_jaxpr

* fix type error: handle int case
2020-04-23 18:07:51 -07:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the the C++ XLA API. It dates back from when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Lucas Beyer
685c0de99b
Fix confusing documentation typo. (#2773) 2020-04-20 11:37:05 -07:00
Peter Hawkins
7f6ede1b0a
Make type of value_and_grad slightly more precise. (#2704) 2020-04-13 20:52:55 -04:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Jake VanderPlas
f3a57656ed jit: raise TypeError if called on generator function 2020-04-07 13:55:35 +03:00
George Necula
7a5d1bc077 Expand docstring for vmap with details about out_axes, and improve error
checking

The newly added test cases used to raise the following kinds of exceptions:

AttributeError: 'float' object has no attribute 'shape'

ValueError: (0, None)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
    arg 0 has shape (2,) and axis None is to be mapped
        so

TypeError: only integer scalar arrays can be converted to a scalar index.
2020-04-06 12:11:29 +03:00
Matthew Johnson
83b9575145 add callable typechecks to more api.py functions 2020-03-31 18:46:15 -07:00
Matthew Johnson
e7be43da8a update api.py docstrings for sphinx highlighting 2020-03-30 21:09:12 -07:00
Matthew Johnson
f766c5e7b5 allow duck-typing in xla_computation arguments 2020-03-30 11:31:29 -07:00
Matthew Johnson
9d8823c912 add initial_style_staging to custom_transforms 2020-03-30 00:41:04 -07:00
Matthew Johnson
a6a837a65e add some stage_out=True indicators 2020-03-30 00:35:45 -07:00
Matthew Johnson
e7f8503c87
Merge pull request #2500 from google/custom-jvp-fix
revise custom_jvp / custom_vjp rule jaxpr staging
2020-03-29 23:04:00 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
1762a86531 workaround for pmap output PRED arrays on cpu/gpu 2020-03-29 14:45:17 -07:00
Matthew Johnson
3700fab73b remove deprecation warnings 2020-03-25 20:24:11 -07:00
Matthew Johnson
0cf84f925b fix custom_transforms bug 2020-03-23 18:43:02 -07:00
Matthew Johnson
83bf048f8a separate out deprecated custom_transforms stuff 2020-03-23 16:45:28 -07:00
Matthew Johnson
c76f32b1be remove jarrett and _make_graphviz, bitrot
might want to revive jarrett later!
2020-03-23 12:18:59 -07:00
George Necula
f658eb5bf5
Add back support for custom_transforms (#2484)
* add also the tests
* mark the old APIs as deprecated
2020-03-22 19:50:06 +01:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Peter Hawkins
578e5cf6d7
Fix return type for vjp. (#2462)
Fix vjp doc string.
2020-03-19 14:54:04 -04:00
Peter Hawkins
afdd1a7367
Add more return types to api.py. (#2452) 2020-03-19 10:28:29 -04:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Peter Hawkins
985d5f7327
Fix Python 3.5 support. (#2439)
* Fix Python 3.5 compatibility problems.
2020-03-17 17:01:04 -04:00
jheek
5f2d2254a0
Fix typo in ShapeDtypeStruct (#2253)
* fix ShapeDtypeSTruct dtype bug

* move dtype conversion to constructor
2020-03-17 07:45:17 +01:00
Matthew Johnson
efaedf4b57 undo previous commit 2020-03-16 12:13:25 -07:00
Matthew Johnson
5280793191 fix custom_transforms + jit bug from #2416 2020-03-16 10:23:24 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
92a0b3d40a add basic pytree support to jet 2020-03-15 09:58:54 -07:00
Matthew Johnson
e84a621184 new jet implementation, with conv-based rules
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Matthew Johnson
ebbcbad547
allow vmap in_axes to be a list, fixes #2367 (#2395) 2020-03-10 08:29:46 -07:00