143 Commits

Author SHA1 Message Date
Adam Paszke
d265fd5604 Ignore named shape when checking aval equality in AD
AD of code with named axes is still WIP, and pmap still doesn't take
proper care to handle them, so weaken the check for now.

PiperOrigin-RevId: 369265258
2021-04-19 11:35:34 -07:00
Matthew Johnson
9d6263a743 support implicit broadcasting in transpose rules 2021-04-16 12:51:11 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
632876d773 Copybara import of the project:
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:

remove deprecated custom_transforms code

PiperOrigin-RevId: 366108489
2021-03-31 13:50:56 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
a261c768f2 simplify cotangent dropping in ad.backward_pass
Since variables in jaxprs are only assigned-to once, when we transpose
them we end up reading a variable's cotangent value only once. That
means we can pop the cotangent environment's reference to a cotangent
value in read_cotangent.
2021-03-24 16:25:50 -07:00
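A minimal sketch of the idea in the commit above, with a hypothetical dict-based cotangent environment standing in for the internals of `ad.backward_pass`:

```py
# ct_env maps jaxpr variables to accumulated cotangent values.
def read_cotangent(ct_env, var, zero):
    # Each jaxpr variable is assigned once, so its cotangent is read at most
    # once during transposition; popping drops the environment's reference and
    # lets the underlying buffer be freed as soon as it is consumed.
    return ct_env.pop(var, zero)
```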
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Matthew Johnson
786970130f add dynamic-shape-jaxpr experimental file
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-02-14 13:11:44 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src.util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
George Necula
81990b4047 Merge pull request #5243 from gnecula:print_refactor
PiperOrigin-RevId: 350362752
2021-01-06 18:07:39 +01:00
George Necula
aec6192ed9 Fix float0 handling 2021-01-06 14:15:42 +02:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
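A rough sketch of the hashing idea described above; the real `HashableFunction` lives in jax's internals and differs in detail, but the key point is that the wrapped function's Python bytecode is part of the cache key:

```py
class HashableFunction:
    """Sketch only: wraps a function so it can serve as part of a cache key.

    Equality and hashing are based on a user-supplied closure plus the
    function's bytecode, mirroring the idea in the commit above.
    """
    def __init__(self, f, closure=()):
        self.f = f
        self.key = (closure, f.__code__.co_code)

    def __call__(self, *args, **kwargs):
        return self.f(*args, **kwargs)

    def __hash__(self):
        return hash(self.key)

    def __eq__(self, other):
        return isinstance(other, HashableFunction) and self.key == other.key
```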
Matthew Johnson
8fef9e8a56 clean up symbolic zero handling in transpose rules
We had a few rules in which:
1. transpose rules didn't correctly handle symbolic zero cotangent
inputs, and/or
2. transpose rules returned the ad_util.Zero class rather than an
instance of it (and/or returned a singleton where they should return a
list)

This solves a FIXME in the code left by @apaszke.

(Second attempt at this fix after #5030 was rolled back due to the div
transpose rule not being cleaned up yet, causing a new check I added to
fail.)
2020-11-28 10:19:04 -08:00
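To illustrate point 2 above, a hedged sketch with a hypothetical transpose rule; `Zero` stands in for `ad_util.Zero`:

```py
class Zero:
    """Stand-in for jax's ad_util.Zero: a symbolic zero that carries an aval."""
    def __init__(self, aval):
        self.aval = aval

def hypothetical_transpose_rule(cotangent, x, *, aval):
    # Correct: one entry per primal input, each symbolic zero an *instance*
    # of Zero so downstream code can read its aval.
    return [Zero(aval)]
    # The variants this commit cleans up:
    #   return Zero          # the class itself rather than an instance
    #   return Zero(aval)    # a bare value where a list is expected
```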
jax authors
dd9526bd39 Copybara import of the project:
--
cd81117b6a2899411646c45fa9c27b676fc47f86 by Matthew Johnson <mattjj@google.com>:

clean up symbolic zero handling in transpose rules

We had a few rules in which:
1. transpose rules didn't correctly handle symbolic zero cotangent
inputs, and/or
2. transpose rules returned the ad_util.Zero class rather than an
instance of it (and/or returned a singleton where they should return a
list)

This solves a FIXME in the code left by @apaszke.

PiperOrigin-RevId: 344571259
2020-11-27 18:01:48 -08:00
Matthew Johnson
cd81117b6a clean up symbolic zero handling in transpose rules
We had a few rules in which:
1. transpose rules didn't correctly handle symbolic zero cotangent
inputs, and/or
2. transpose rules returned the ad_util.Zero class rather than an
instance of it (and/or returned a singleton where they should return a
list)

This solves a FIXME in the code left by @apaszke.
2020-11-27 14:32:59 -08:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
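A hedged usage sketch of the new parameter from the user's perspective (assumes at least one device; the mapped axis size must equal the device count):

```py
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(float(n * 3)).reshape(n, 3)   # mapped axis 0 has size n

# With out_axes=1, per-device results are stacked along axis 1 of the output.
y = jax.pmap(lambda v: v * 2, out_axes=1)(x)
print(y.shape)   # (3, n)
```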
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
jax authors
ec74c68c14 Merge pull request #4863 from apaszke:pmap-in-axes
PiperOrigin-RevId: 341715953
2020-11-10 15:58:22 -08:00
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 18:35:28 +00:00
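A hedged usage sketch of mapping over a non-leading axis (assumes at least one device):

```py
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(float(3 * n)).reshape(3, n)   # mapped axis is axis 1

# in_axes=1 maps over the second axis of x; outputs are stacked along axis 0
# by default, so the device axis moves to the front.
y = jax.pmap(lambda v: v + 1, in_axes=1)(x)
print(y.shape)   # (n, 3)
```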
jax authors
bdd7915661 Internal change
PiperOrigin-RevId: 341644256
2020-11-10 10:12:27 -08:00
Adam Paszke
6914058cbe Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 13:35:23 +00:00
Peter Hawkins
7efc1dbc94 [JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
jax authors
4a20eea828 Copybara import of the project:
--
609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689
2020-10-16 00:21:32 -07:00
Matthew Johnson
8678287644 remove an extraneous replace_float0s
This caused a test failure when trying to land #4008.
2020-10-15 16:44:39 -07:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Lena Martens
e3d622cc67 Recast int/bool tangents to float0 in custom_jvp/vjps
(also in the initial_style path).
2020-10-08 17:37:35 +01:00
Lena Martens
cc0114a0a9 Fix dtype behavior with float0s in CustomVJP. 2020-10-01 15:17:51 +01:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
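A hedged usage sketch of what this enables (it assumes `jax.grad`'s `allow_int` option, which landed around the same period); the cotangent for an integer argument comes back with the trivial `float0` dtype:

```py
import jax

def f(x, n):
    return x * n

# Differentiate with respect to both a float and an int argument.
gx, gn = jax.grad(f, argnums=(0, 1), allow_int=True)(2.0, 3)
print(gx)                               # 3.0
print(gn.dtype == jax.dtypes.float0)    # True: int primal -> float0 tangent
```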
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
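A hedged sketch of the user-visible effect after this change: `jax.make_jaxpr` returns a `ClosedJaxpr`, which packages the jaxpr with its constants and derives the avals rather than requiring them to be passed in:

```py
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jnp.sin(x) + 1.0)(2.0)
print(type(closed).__name__)   # ClosedJaxpr
print(closed.jaxpr)            # the underlying open jaxpr
print(closed.consts)           # the constant values it is closed over
print(closed.in_avals)         # derived from the jaxpr, not constructed by hand
```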
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00
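A hedged illustration of the memory effect described above (hypothetical example): with omnistaging enabled, the large constant below is staged into the jitted computation rather than materialized as a device constant at trace time:

```py
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Staged out as part of the jaxpr instead of being instantiated as a
    # (1000, 1000) constant while tracing.
    big = jnp.ones((1000, 1000))
    return x + big.sum()

print(f(1.0))   # 1000001.0
```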
Neil Girdhar
503e5973ce Make vjp cotangent functions pytree-like
Fixes #3667
2020-07-10 23:22:38 -04:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Matthew Johnson
fb8b3a4df9 symbolic zeros tweak to work with div rule 2020-06-08 13:16:19 -07:00
Matthew Johnson
982f86702e clean up handling of aux data in jvp_subtrace_aux
related to #3222 indirectly, in that we now won't try to do something
crazy like `JVPTracer(trace, object(), zero)`, which #3222 didn't like
2020-06-08 11:48:58 -07:00
Jake Vanderplas
528207bd80
fix flakes at head (#3361) 2020-06-08 10:45:00 -07:00
Matthew Johnson
508821cc18
fix a one-character issue from a bad merge (#3362)
cf. #3222
2020-06-08 09:47:32 -07:00
Adam Paszke
e2a32f741d Rebase fixes 2020-06-05 15:52:03 +00:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for the remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Adam Paszke
74d160f5e0
Don't keep primal arguments and results in the linearized jaxpr (#3233)
Linearized functions are supposed to take tangent types to tangent
types, and so all primal arguments are unused and primal results get
replaced by units.
2020-06-05 17:22:55 +02:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device memory are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation; however, it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Adam Paszke
8f2d72eb40
Simplify handling of non-linear equations in backward_pass and fix remat (#3162)
Previously, `backward_pass` had been generalized to be able to handle
non-linear computation in the body, but it could easily get confused
into doing unnecessary work only to throw it away later. Additionally, it
treated any call primitive embedded inside remat like remat itself,
which is obviously wrong.

This patch fixes both of those issues and simplifies a bunch of the code
at the same time. `backward_pass` now has an invariant that it only
deals with jaxprs containing linear equations, and becomes
a simple transposing interpreter again.

**Background on JVP vs linearization**

Ok, so why does this change actually fix the problem? It is important to
understand that JVP and linearization transforms are actually two
different things, even though we often identify them as one. Both take
in a function of type `a -> b`, but their ranges are different! JVP
returns a function of type `(a, T a) -> (b, T b)` while linearization
returns `a -> (b, T a --o T b)`. Note that the second type carries more
information, because we get a guarantee that (1) `b` does not depend on
`T a` and (2) the dependence of `T b` on `T a` is linear.

The reason why we usually treat them as equivalent, is that they can be
shown to be "isomorphic". If we take the output of linearization, we can
make it a JVP-like function using the following combinator:
```haskell
jvp f = \a ta -> let (b, lf) = linearize f a in (b, lf ta)
```
More importantly for JAX, which doesn't have a linearization interpreter,
if we assume (1) and (2), linearization can be recovered in terms of jvp
as well:
```haskell
linearize f = \a -> let fjvp = jvp f in
                    partial_eval fjvp (Known a) Unknown
```
That is, if we have a mathematically correct JVP, then linearization is
simply partial evaluation with all primal values marked as known, and
all tangents treated as yet unknown values.

One important performance consideration is that for forward-mode AD we
really want to use the JVP formulation, which can interleave the computation
of primals and tangents, instead of sequencing them and increasing the memory
cost. On the other hand, transposition (necessary for VJPs!) can only be
applied to linear functions, and so it can't possibly work on the output
of JVP. It really can only be applied to the second output of the
linearization transform. Hence, we really care about both, but can we avoid
having two very similar implementations of (approximately) the same thing?
It seems that the answer is yes, because of the equivalence outlined above!

**If all this is so nice, then what's the problem?**

The problem is, of course, remat. Partial eval is able to thread the
known/unknown information correctly through regular call primitives, but
mind you, remat is no regular call primitive! Once we enter remat, we are
no longer interested in treating _anything_ like a known value. After
all, our goal here is to record an accurate trace of everything that has
happened in the body of a remat, including the primal (known!)
computation. This however presents a challenge for implementing
linearization in terms of JVP, because inside the body of remat we break
the assumption that known/unknown corresponds to the primal/tangent
distinction. Its body, instead of representing the second output of
linearization simply contains the traced JVP code now...

One way to fix it would be to implement a proper linearization pass that
would track the distinction between primal and tangent information while
still allowing us to stage out code for primals. @mattjj and I have even
started hacking together an implementation for that.

I've been trying to convince @mattjj that there is no other way to go
about it, but I couldn't really convince him that this is the case.
Then, once I wanted to write a semi-formal proof I could no longer even
convince myself! Turns out that there is an alternative solution!

What this patch does is stop caring about whether the output of the
`linearize` function (defined as JVP + partial eval, as discussed above)
is a good linearization. It still is if you don't use remats in your
code, but it breaks miserably once you do. However, as long as all
the complications are contained solely in the `call_jaxpr` embedded inside
a remat, we still have a chance to fix them! This is because the
transposition interpreter never reaches into those bodies directly, but
rather asks the call primitive to transpose itself.

Now, how do you transpose remat? We can't just reuse the code used for
regular call primitives (this is what happens now BTW), because unlike
for them, the `call_jaxpr` doesn't represent a linear function! But it's
not completely useless either --- it contains the traced JVP code. So,
how do we get from there to a linear function? Partial eval! And if you
think about it, it is exactly what we wanted --- we end up evaluating all
the primal code in the body once again, while only staging out the tangent
computation, to be passed into the transposing interpreter again.

Fin.
2020-05-27 20:22:40 +02:00
notEvil
969ed8085c
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-05-06 10:15:17 -04:00