4444 Commits

Author SHA1 Message Date
Julius Kunze
02b4fd3500
Fix broadcast_shapes for polymorphic dims (#3216) (#3224)
* Fix #3216

* Simplify
2020-05-27 18:15:01 -04:00
Peter Hawkins
7d96aae579
Fix bug in pytype fix. (#3229) 2020-05-27 16:44:38 -04:00
Jake Vanderplas
41292a2af6
mention numpy & scipy convolve functions in gotchas doc. (#3214) 2020-05-27 13:20:27 -07:00
Peter Hawkins
336a0d6ea4
Fix pytype error. (#3226)
* Fix pytype error.

* Incorporate review comment.
2020-05-27 16:13:31 -04:00
Peter Hawkins
94b4ccd627
Relax test tolerance on lax_scipy_test to fix a test failure on Skylake machines at LLVM head. (#3225) 2020-05-27 15:34:35 -04:00
Roy Frostig
c5010cda47 use new gensym in host_callback jaxpr rewriter 2020-05-27 12:03:34 -07:00
Roy Frostig
da9d8c9c7e update reference jaxprs in tests 2020-05-27 12:03:34 -07:00
Roy Frostig
1b020fe2a9 update core.gensym consumers, address rewrite TODOs in lax control flow rules 2020-05-27 12:03:34 -07:00
Roy Frostig
e80e9634a7 jaxpr-dependent gensym to avoid var duplication 2020-05-27 12:03:34 -07:00
Adam Paszke
8f2d72eb40
Simplify handling of non-linear equations in backward_pass and fix remat (#3162)
Previously, `backward_pass` had been generalized to be able to handle
non-linear computation in the body, but it could easily get confused
into doing unnecessary work only to throw it away later. Additionally, it
treated any call primitive embedded inside remat like remat itself,
which is obviously wrong.

This patch fixes both of those issues and simplifies a bunch of the code
at the same time. `backward_pass` now maintains the invariant that it
only deals with jaxprs containing linear equations, and so it becomes
a simple transposing interpreter again.

**Background on JVP vs linearization**

Ok, so why does this change actually fix the problem? It is important to
understand that JVP and linearization transforms are actually two
different things, even though we often identify them as one. Both take
in a function of type `a -> b`, but their ranges are different! JVP
returns a function of type `(a, T a) -> (b, T b)` while linearization
returns `a -> (b, T a --o T b)`. Note that the second type carries more
information, because we get a guarantee that (1) `b` does not depend on
`T a` and (2) the dependence of `T b` on `T a` is linear.
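In JAX terms, this difference is visible in the signatures of `jax.jvp` and `jax.linearize`. A minimal sketch (the function `f` here is just an illustrative example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# JVP has type (a, T a) -> (b, T b): primal and tangent go in together.
y, y_dot = jax.jvp(f, (2.0,), (1.0,))

# Linearization has type a -> (b, T a --o T b): the primal output comes
# with a function that is linear in its tangent argument.
y2, f_lin = jax.linearize(f, 2.0)
t1 = f_lin(1.0)   # same tangent as y_dot above
t3 = f_lin(3.0)   # linearity: scaling the input tangent scales the output
```

Guarantee (1) shows up as `y2` being computed without any tangent input; guarantee (2) as `f_lin` being linear, so `f_lin(3.0)` equals `3 * f_lin(1.0)`.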

The reason we usually treat them as equivalent is that they can be
shown to be "isomorphic". If we take the output of linearization, we can
make it a JVP-like function using the following combinator:
```haskell
jvp f = \a ta -> let (b, lf) = linearize f a in (b, lf ta)
```
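The same combinator can be written against the JAX API; a sketch (the name `jvp_from_linearize` and the function `f` are introduced here purely for illustration):

```python
import jax
import jax.numpy as jnp

def jvp_from_linearize(f, primals, tangents):
    # The combinator above: linearize at the primals, then apply
    # the resulting linear map to the tangents.
    b, lf = jax.linearize(f, *primals)
    return b, lf(*tangents)

def f(x):
    return jnp.sin(x) * x

y1, t1 = jvp_from_linearize(f, (2.0,), (1.0,))
y2, t2 = jax.jvp(f, (2.0,), (1.0,))  # agrees with the combinator
```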
More importantly for JAX, which doesn't have a linearization interpreter,
if we assume (1) and (2), linearization can be recovered in terms of jvp
as well:
```haskell
linearize f = \a -> let fjvp = jvp f in
                    partial_eval fjvp (Known a) Unknown
```
That is, if we have a mathematically correct JVP, then linearization is
simply partial evaluation with all primal values marked as known, and
all tangents treated as yet unknown values.

One important performance consideration is that for forward-mode AD we
really want to use the JVP formulation, which can interleave the computation
of primals and tangents, instead of sequencing them and increasing the memory
cost. On the other hand, transposition (necessary for VJPs!) can only be
applied to linear functions, and so it can't possibly work on the output
of JVP. It really can only be applied to the second output of the
linearization transform. Hence, we really care about both, but can we avoid
having two very similar implementations of (approximately) the same thing?
It seems that the answer is yes, because of the equivalence outlined above!
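This is also why transposition in JAX only ever needs to handle linear code: `jax.vjp` can be thought of as linearization followed by transposition of the linear part. A hedged sketch (illustrative `f`):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# linearize yields the linear map T a --o T b ...
y, f_lin = jax.linearize(f, 2.0)

# ... which, being linear, can be transposed into T b --o T a.
f_lin_t = jax.linear_transpose(f_lin, 1.0)

# vjp composes the two steps, so the cotangent pullbacks agree.
_, f_vjp = jax.vjp(f, 2.0)
(ct_lin,) = f_lin_t(1.0)
(ct_vjp,) = f_vjp(1.0)
```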

**If all this is so nice, then what's the problem?**

The problem is, of course, remat. Partial eval is able to thread the
known/unknown information correctly through regular call primitives, but
mind you, remat is no regular call primitive! Once we enter remat, we are
no longer interested in treating _anything_ like a known value. After
all, our goal here is to record an accurate trace of everything that has
happened in the body of a remat, including the primal (known!)
computation. This however presents a challenge for implementing
linearization in terms of JVP, because inside the body of remat we break
the assumption that known/unknown corresponds to the primal/tangent
distinction. Its body, instead of representing the second output of
linearization, simply contains the traced JVP code now...
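From the user's point of view, the property this patch restores is simply that differentiating through remat gives the same answer as differentiating the plain function; a minimal check (illustrative `f`):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.sin(x)) * x

# remat recomputes f's primal values during the backward pass instead of
# saving them; the gradient itself must be unchanged.
g_plain = jax.grad(f)(2.0)
g_remat = jax.grad(jax.remat(f))(2.0)
```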

One way to fix it would be to implement a proper linearization pass that
would track the distinction between primal and tangent information while
still allowing us to stage out code for primals. @mattjj and I have even
started hacking together an implementation for that.

I've been trying to convince @mattjj that there is no other way to go
about it, but I couldn't really convince him that this is the case.
Then, once I tried to write a semi-formal proof, I could no longer even
convince myself! Turns out that there is an alternative solution!

What this patch does is, it stops caring about the output of the
`linearize` function (defined as JVP + partial eval, as discussed above)
to be a good linearization. It still is, as long as you don't use remat
in your code, but it breaks miserably once you do. However, as long as all
the complications are contained solely in the `call_jaxpr` embedded inside
a remat, we still have a chance to fix them! This is because the
transposition interpreter never reaches into those bodies directly, but
rather asks the call primitive to transpose itself.

Now, how do you transpose remat? We can't just reuse the code used for
regular call primitives (this is what happens now BTW), because unlike
for them, the `call_jaxpr` doesn't represent a linear function! But it's
not completely useless either --- it contains the traced JVP code. So,
how do we get from there to a linear function? Partial eval! And if you
think about it, it is exactly what we wanted --- we end up evaluating all
the primal code in the body once again, while only staging out the tangent
computation, to be passed into the transposing interpreter again.

Fin.
2020-05-27 20:22:40 +02:00
gaurav pathak
ec3b593ca8
Added geometric distribution to scipy stats (#3205) 2020-05-27 09:37:55 -07:00
Benjamin Kramer
e3b046bc45
Adjust complex64 tolerance for upcoming XLA change (#3218)
It makes sense for complex64 tolerance to be the same as float32
tolerance here. While there, re-enable the test on GPU, which was blocked
on a bug that's long gone.
2020-05-27 10:29:00 -04:00
Lena Martens
1cc471928b
Remove pe from name_stack and test. (#3209) 2020-05-27 00:59:31 -07:00
Matthew Johnson
9f8a4ad341
remove stray print statement from #1529 2020-05-26 20:01:36 -07:00
Skye Wanderman-Milne
6ffde8061d
Implement pmap of sharded_jit (#3144)
* Implement pmap of sharded_jit

* Update jax/interpreters/pxla.py

Co-authored-by: James Bradbury <jekbradbury@google.com>

* Address comments

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-26 14:26:53 -07:00
Du Phan
e526109a73
Remove dtype warning for np.quantile (#3188)
* drop the warning in index_to_gather

* fix dtype issue at quantile

* revert the change, the issue seems to be fixed
2020-05-26 15:41:01 -04:00
George van den Driessche
0f230029c9
Add a JAX flag to avoid most optimizations. (#3208) 2020-05-26 15:21:22 -04:00
Jean-Baptiste Lespiau
a486f54814
Add a summary explaining the usage and context for JAX PRNG design. (#2525)
* Add a summary explaining the usage and context for JAX PRNG design.

The current design_notes do not match current JAX API, and it's a pretty
long doc to read to understand how to use it.

Closes: #2087

* Change 'should' to be more precise.

* Address comments.
2020-05-26 10:38:28 +03:00
George Necula
f18f7920ba
Fix error in code generation of batched while loops (#3207)
Fixed the case when the value is a unit, which we do not batch.
2020-05-26 10:22:33 +03:00
James Bradbury
0eace80a6e
Fix experimental host callback on multi-host (#3200)
* Fix experimental host callback on multi-host

Hosts can only access the outfeed queue for local devices, while `api.devices` returns all devices in the system.

* Update host_callback.py
2020-05-25 08:12:58 +03:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
George Necula
afadb12b64
Improved tapping support for while: tap inside cond, vmap of while (#3195)
* Improved tapping support for while: tap inside cond, vmap of while

* Fix float64->float32 in tests
2020-05-24 10:50:07 +03:00
Stephan Hoyer
9d6744fc0c
Cleanup test_custom_root_scalar and re-enable it for TPUs (#3184)
The test passes now on TPUs, thanks to the new ``lax.integer_pow`` primitive.
2020-05-23 11:45:27 -07:00
George Necula
b493a7e5df
Fix the handling of repeated vmap for id_tap (#3132)
* Fix the handling of repeated vmap for id_tap

* Updated the transforms to always be a tuple of tuples

* Changed the transforms to be dictionaries
2020-05-23 13:49:27 +03:00
Penn
de800d2046
fix #3165 if round half up behavior is desired (#3166)
* fix issue 3165 if round half up behaviour is desired

* add test for round half

* fix integer array input and add to test

* fix truncated integer output

* match input and output dtypes

* fix asymmetric rounding and extend test

* use lax for rounding
2020-05-22 15:41:37 -07:00
Jascha Sohl-Dickstein
190f88dede
Update Common_Gotchas_in_JAX.ipynb (#3189)
typo fix
2020-05-22 14:12:44 -07:00
Roy Frostig
96c20f3237
Merge pull request #2734 from google/tycheck
typecheck jaxprs
2020-05-21 22:07:24 -07:00
Roy Frostig
c293a102b2 work around mypy 2020-05-21 20:54:02 -07:00
Roy Frostig
69d7bcf7fb except-and-raise during jaxpr checking, adding jaxpr as context, and simplify type environment 2020-05-21 20:02:30 -07:00
Roy Frostig
8e61ce8d1a fix unitvar comparisons and move to class attributes 2020-05-21 18:28:09 -07:00
Skye Wanderman-Milne
ecd893626f Address comments 2020-05-21 14:50:16 -07:00
joao guilherme
77e4d8b3b9
Updates onp -> np in random, loops, jet and in the tests of stax and optix (#3182) 2020-05-21 14:12:18 -07:00
Skye Wanderman-Milne
d8ede0106a
Update jax/interpreters/pxla.py
Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-21 14:00:58 -07:00
Skye Wanderman-Milne
a3e0cd1293
Fix pxla.shard_args bug (#3170) 2020-05-21 13:52:03 -07:00
Roy Frostig
5d125539b6 axis_index abstract eval rule 2020-05-21 13:21:07 -07:00
Roy Frostig
1a91662654 return tuple in psum abstract eval rule 2020-05-21 13:21:07 -07:00
Roy Frostig
1d7808169b use jax.numpy in jaxpr typecheck tests 2020-05-21 13:21:07 -07:00
Roy Frostig
7ff389bd03 extend type transfer to all primitives, including call and map primitives 2020-05-21 13:21:07 -07:00
Roy Frostig
e2cc568997 raise type errors consistently in jaxpr checker 2020-05-21 13:21:07 -07:00
Roy Frostig
6475f60ce9 fix import in core_test 2020-05-21 13:21:07 -07:00
Roy Frostig
1e55603344 avoid attempt to read literals from the typechecking environment 2020-05-21 13:21:07 -07:00
Roy Frostig
060bd8a4f6 tidy jaxpr typechecking error test 2020-05-21 13:21:07 -07:00
Roy Frostig
0f109d9fe0 add jaxpr context to typechecker error message 2020-05-21 13:21:07 -07:00
Roy Frostig
3705252be6 have UnitVar subclass Var (caught by mypy) 2020-05-21 13:21:07 -07:00
Roy Frostig
42e7e20eab update check_jaxpr doc 2020-05-21 13:21:07 -07:00
Roy Frostig
cc34ed2693 check aval compatibility, not strict equality, when typechecking jaxpr equations 2020-05-21 13:21:07 -07:00
Roy Frostig
0c2c558482 check that variables are typed equally throughout a jaxpr 2020-05-21 13:21:07 -07:00
Roy Frostig
8e70769cba factor out jaxpr-check context and variable environment 2020-05-21 13:21:07 -07:00
Roy Frostig
1205f7a00f factor out jaxpr equation checks 2020-05-21 13:21:07 -07:00
Roy Frostig
94b1f631ea raise TypeError for jaxpr typechecking errors 2020-05-21 13:21:07 -07:00