Previously, `backward_pass` has been generalized to be able to handle
non-linear computation in the body, but it could easily get confused
into doing unnecessary work only to throw it away later. Additionally, it
treated any call primitive embedded inside remat like remat itself,
which is obviously wrong.
This patch fixes both of those issues and simplifies a bunch of the code
at the same time. `backward_pass` now has an invariant that it only
deals with jaxprs containing linear equations alone, and becomes
a simple transposing interpreter again.
**Background on JVP vs linearization**
Ok, so why does this change actually fix the problem? It is important to
understand that JVP and linearization transforms are actually two
different things, even though we often identify them as one. Both take
in a function of type `a -> b`, but their ranges are different! JVP
returns a function of type `(a, T a) -> (b, T b)` while linearization
returns `a -> (b, T a --o T b)`. Note that the second type carries more
information, because we get a guarantee that (1) `b` does not depend on
`T a` and (2) the dependence of `T b` on `T a` is linear.
The reason why we usually treat them as equivalent, is that they can be
shown to be "isomorphic". If we take the output of linearization, we can
make it a JVP-like function using the following combinator:
```haskell
jvp f = \a ta -> let (b, lf) = linearize f in (b, lf ta)
```
More importantly for JAX, which doesn't have a linearization interpreter,
if we assume (1) and (2), linearization can be recovered in terms of jvp
as well:
```haskell
linearize f = \a -> let fjvp = jvp f in
partial_eval fjvp (Known a) Unknown
```
That is, if we have a mathematically correct JVP, then linearization is
simply partial evaluation with all primal values marked as known, and
all tangents treated as yet unknown values.
One important performance consideration is that for forward-mode AD we
really want to use the JVP formulation, which can interleave the computation
of primals and tangents, instead of sequencing them and increasing the memory
cost. On the other hand, transposition (necessary for VJPs!) can only be
applied to linear functions, and so it can't possibly work on the output
of JVP. It really can only be apply to the second output of the
linearization transform. Hence, we really care about both, but can we avoid
having two very similar implementations of (approximately) the same thing?
It seems that the answer is yes, because of the equivalence outlined above!
**If all this is so nice, then what's the problem?**
The problem is, of course, remat. Partial eval is able to thread the
known/unknown information correctly through regular call primitives, but
mind you, remat is no regular call primitive! Once we enter remat, we are
no longer interested in treating _anything_ like a known value. After
all, our goal here is to record an accurate trace of everything that has
happened in the body of a remat, including the primal (known!)
computation. This however presents a challenge for implementing
linearization in terms of JVP, because inside the body of remat we break
the assumption that known/unknown corresponds to the primal/tangent
distinction. Its body, instead of representing the second output of
linearization simply contains the traced JVP code now...
One way to fix it would be to implement a proper linearization pass that
would track the distinciton between primal and tangent information while
still allowing to stage out code for primals. @mattjj and I have even
started hacking together an implementation for that.
I've been trying to convince @mattjj that there is no other way to go
about it, but I couldn't really convince him that this is the case.
Then, once I wanted to write a semi-formal proof I could no longer even
convince myself! Turns out that there is an alternative solution!
What this patch does is, it stops caring about the output of the
`linearize` function (defined as JVP + partial eval, as discussed above)
to be a good linearization. It still is if you don't use remats in your
code, but it still breaks miserably once you do. However, as long as all
the complications are contained solely in the `call_jaxpr` embedded inside
a remat, we still have a chance to fix them! This is because the
transposition interpreter never reaches into those bodies directly, but
rather asks the call primitive to transpose itself.
Now, how do you transpose remat? We can't just reuse the code used for
regular call primitives (this is what happens now BTW), because unlike
for them, the `call_jaxpr` doesn't represent a linear function! But it's
not completely useless either --- it contains the traced JVP code. So,
how do we get from there to a linear function? Partial eval! And if you
think about it, it is exactly what we wanted --- we end up evaluating all
the primal code in the body once again, while only staging out the tangent
computation, to be passed into the transposing interpreter again.
Fin.
It makes sense for complex64 tolerance to be the same as float32
tolerance here. While there re-enable the test on GPU, which was blocked
on a bug that's long gone.
* Add a summary explaining the usage and context for JAX PRNG design.
The current design_notes do not match current JAX API, and it's a pretty
long doc to read to understand how to use it.
Closes: #2087
* Change 'should' to be more precise.
* Address comments.
* Fix experimental host callback on multi-host
Hosts can only access the outfeed queue for local devices, while `api.devices` returns all devices in the system.
* Update host_callback.py
* Added argument check to all primitives.
The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.
This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.
In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.
* Merged find_top_trace with check_args
This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
* fix issue 3165 if round half up behaviour is desired
* add test for round half
* fix integer array input and and add to test
* fix truncated integer output
* match input and output dtypes
* fix asymmetric rounding and extend test
* use lax for rounding