# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
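
"""Forward- and reverse-mode automatic differentiation interpreters.

This module implements the JVP (forward-mode) trace, linearization via
partial evaluation of the JVP, and the transposition interpreter
(`backward_pass`) that turns linear jaxprs into VJPs for reverse-mode
differentiation.
"""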

import functools
import itertools as it
from typing import Any, Callable, Dict, Set, List

from . import partial_eval as pe
from ..config import config
from .. import core
from ..dtypes import dtype, float0
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
                    raise_to_shaped)
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
                       zeros_like_p, Zero)
from ..util import (unzip2, safe_map, safe_zip, partial, split_list, wrap_name,
                    as_hashable_function)
from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, Partial
from jax._src import source_info_util

zip = safe_zip
map = safe_map
def identity(x): return x


def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
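  """Set up `fun` for forward-mode differentiation.

  The returned `lu.WrappedFun` maps `(primals, tangents)` to
  `(out_primals, out_tangents)`; with `has_aux=True` a thunk for the
  auxiliary output is returned as well.

  Illustrative sketch (an assumption-laden example, not a call site from
  this module; it presumes `jax.numpy` is importable as `jnp` so the traced
  primitive has a JVP rule, and that `fun` returns a flat list)::

    f = lu.wrap_init(lambda x: [jnp.sin(x)])
    (out_primal,), (out_tangent,) = jvp(f).call_wrapped((3.,), (1.,))
  """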
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate), aux


@lu.transformation
def jvpfun(instantiate, primals, tangents):
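  """Runs the wrapped function under a new `JVPTrace` main trace.

  Tangents with `float0` dtype are canonicalized to symbolic `Zero`s, and
  output tangents are instantiated to concrete zeros where `instantiate`
  asks for it.
  """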
  tangents = [Zero.from_value(t) if not isinstance(t, Zero)
              and dtype(t) is float0 else t for t in tangents]
  with core.new_main(JVPTrace) as main:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    del main
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t for t, inst
                  in zip(out_tangents, instantiate)]
  yield out_primals, out_tangents


@lu.transformation
def jvp_subtrace(main, primals, tangents):
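  """Evaluates the wrapped function with `JVPTracer`s at the current sublevel.

  Inputs whose tangents are symbolic zeros are passed through as plain
  primal values; the outputs are unzipped into `(out_primals, out_tangents)`.
  """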
  trace = JVPTrace(main, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.primal, out_tracer.tangent)
                for out_tracer in out_tracers])


@lu.transformation_with_aux
def jvp_subtrace_aux(main, primals, tangents):
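  """Like `jvp_subtrace`, but also returns the auxiliary output.

  Auxiliary values are lowered to their primal parts when they are
  `JVPTracer`s belonging to this trace.
  """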
  trace = JVPTrace(main, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals = [core.full_lower(x.primal)
                 if isinstance(x, JVPTracer) and x._trace.level == trace.level
                 else x for x in aux]
  yield (out_primals, out_tangents), aux_primals


def linearize(traceable, *primals, **kwargs):
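  """Linearizes `traceable` at `primals` via JVP plus partial evaluation.

  Returns `(out_primals, out_tangent_pvals, jaxpr, consts)`, where `jaxpr`
  together with `consts` represents the linear tangent map; with
  `has_aux=True` the auxiliary output is appended.
  """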
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()


def vjp(traceable, primals, has_aux=False):
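  """Builds a VJP for `traceable` by linearizing and then transposing.

  Returns `(out_primals, vjp_fn)` where `vjp_fn` is a pytree-compatible
  `Partial` mapping output cotangents to input cotangents; with
  `has_aux=True` the auxiliary output is returned as well.

  Illustrative sketch (an assumption-laden example rather than a call site
  from this module; it presumes `jax.numpy` is importable as `jnp` and that
  `traceable` returns a flat list)::

    f = lu.wrap_init(lambda x: [jnp.sin(x)])
    (y,), f_vjp = vjp(f, (3.,))
    x_bar, = f_vjp(1.0)
  """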
  if not has_aux:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  else:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

  def unbound_vjp(pvals, jaxpr, consts, *cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
    return map(instantiate_zeros, arg_cts)

  # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
  # pass in a custom VJP.
  vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
  if not has_aux:
    return out_primals, vjp_
  else:
    return out_primals, vjp_, aux


def ignore_consts(ct, pval):
  aval, const = pval
  if isinstance(aval, core.AbstractValue):
    return ct
  elif aval is None:
    return core.unit
  else:
    raise TypeError(aval)

def unpair_pval(pval):
  aval, const = pval
  const_1, const_2 = const
  if aval is None:
    return (None, const_1), (None, const_2)
  else:
    aval_1, aval_2 = aval
    return (aval_1, const_1), (aval_2, const_2)


def replace_float0s(primal, tangent):
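  """Replaces a `float0` tangent with a zero of the primal's dtype."""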
  if dtype(tangent) is float0:
    return core.zeros_like_float0(tangent, dtype(primal))
  else:
    return tangent


def recast_to_float0(primal, tangent):
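  """Recasts the tangent to a symbolic `Zero` when the primal's tangent
  dtype is `float0`."""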
  if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
    return Zero(get_aval(primal).at_least_vspace())
  else:
    return tangent


# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
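  """Transposes a linear jaxpr: maps output cotangents to input cotangents.

  Walks the equations in reverse, applying each primitive's transpose rule,
  and drops cotangents as soon as their last use has been processed.
  """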
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if not core.skip_checks:
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
      assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    return ct_env.get(v, Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  # forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in jaxpr.eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set

  ct_env: Dict[Any, Any] = {}
  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
  for eqn, to_drop in zip(jaxpr.eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    with source_info_util.user_context(eqn.source_info):
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        cts_in_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
        cts_out = get_primitive_transpose(eqn.primitive)(
            params, call_jaxpr, invals, cts_in, cts_in_avals)
      else:
        cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                         **eqn.params)
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out


class UndefinedPrimal:
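  """Placeholder for a primal value that is unknown during transposition.

  Carries only the abstract value, so transpose rules can tell which inputs
  are linear (undefined) and still recover their avals.
  """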
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)

def is_undefined_primal(x):
  return type(x) is UndefinedPrimal

register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))


def get_primitive_transpose(p):
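  """Looks up the transpose rule for primitive `p`, raising if absent."""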
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Transpose rule (for reverse-mode differentiation) for '{}' "
        "not implemented".format(p)) from err


@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
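  """Auxiliary transformation reporting which output tangents are nonzero."""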
  results = (_, tangents_out) = yield args, kwargs
  yield results, [type(r) is not Zero for r in tangents_out]


class JVPTrace(Trace):
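  """Trace implementing forward-mode differentiation.

  Tracers carry `(primal, tangent)` pairs; primitives are handled by the
  rules registered in `primitive_jvps`, and calls/maps are differentiated by
  pushing a `jvp_subtrace` into the wrapped function.
  """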

  def pure(self, val):
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def lift(self, val):
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def sublift(self, val):
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp = primitive_jvps.get(primitive)
    if not jvp:
      msg = f"Differentiation rule for '{primitive}' not implemented"
      raise NotImplementedError(msg)
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
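    """Differentiates through a call or map primitive.

    Tangents that are symbolic zeros are pruned before calling into the
    wrapped function; for map primitives the axis bookkeeping (`in_axes`,
    `out_axes_thunk`) is adjusted to match the pruned tangent inputs and
    outputs.
    """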
|
2019-07-27 15:46:14 -07:00
|
|
|
assert call_primitive.multiple_results
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
2020-06-23 09:39:45 -07:00
|
|
|
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
|
|
|
|
nz_tangents = [type(t) is not Zero for t in tangents]
|
|
|
|
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
2020-12-02 14:13:05 +00:00
|
|
|
f_jvp = jvp_subtrace(f, self.main)
|
2020-06-23 09:39:45 -07:00
|
|
|
if isinstance(call_primitive, core.MapPrimitive):
|
2020-11-05 11:54:05 +00:00
|
|
|
in_axes = params['in_axes']
|
|
|
|
tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
out_axes_thunk = params['out_axes_thunk']
|
2020-12-02 14:13:05 +00:00
|
|
|
f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
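For reference, the JVP-trace statistic can be recorded with a sketch along
these lines (using the `lu` and `Zero` names that appear elsewhere in this
file):
```py
@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
  results = (_, tangents_out) = yield args, kwargs
  yield results, [type(t) is not Zero for t in tangents_out]
```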
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
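The helper boils down to something like the following sketch (illustrative
names only): hash and compare a thunk by an explicit closure value rather
than by object identity, so that repeated binds of the same `pmap(f)` can
hit the same cache entry.
```py
class HashableThunk:
  """Wraps a thunk so caching keys on `closure` instead of the thunk's identity."""
  def __init__(self, fun, closure):
    self.fun = fun
    self.closure = closure
  def __call__(self):
    return self.fun()
  def __hash__(self):
    return hash(self.closure)
  def __eq__(self, other):
    return isinstance(other, HashableThunk) and self.closure == other.closure
```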
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
that argument is. I hope it's something better than just avoiding tracing
a single jaxpr.
2020-11-09 17:23:16 +00:00
|
|
|
# The new thunk depends deterministically on the old thunk and the wrapped function.
|
|
|
|
# Any caching already has to include the wrapped function as part of the key, so we
|
|
|
|
# only use the previous thunk for equality checks.
|
2020-12-02 14:13:05 +00:00
|
|
|
# NOTE: This assumes that the output tangents being zero is a deterministic
|
|
|
|
# function of which input tangents were zero.
|
|
|
|
@as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
|
2020-11-09 17:23:16 +00:00
|
|
|
def new_out_axes_thunk():
|
|
|
|
out_axes = out_axes_thunk()
|
|
|
|
return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
|
|
|
|
params = dict(params,
|
|
|
|
in_axes=(*in_axes, *tangent_in_axes),
|
|
|
|
out_axes_thunk=new_out_axes_thunk)
|
2020-12-02 14:13:05 +00:00
|
|
|
f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
|
2020-06-23 09:39:45 -07:00
|
|
|
update_params = call_param_updaters.get(call_primitive)
|
|
|
|
new_params = update_params(params, nz_tangents) if update_params else params
|
Add support for buffer donation in `jit` and `pmap`. (#2936)
For a computation of the form:
>>> f = lambda x: x ** 2
>>> f = jax.jit(f)
>>> while run:
... x = f(x)
JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:
1. Users at the limit of available device memory are constrained by the
additional copy of their parameters and other state while they typically only require
one copy. This typically frees 100M+ of device memory and is a critical
optimization for larger models to match state of the art performance in
other frameworks.
2. This constant alloc/free of the input/output buffers can cause memory
fragmentation on some platforms (although having a reusing allocator and
limiting run-ahead may be a better solution for this problem).
We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:
>>> f = lambda x: x ** 2
>>> f = jit(f, donate_argnums=0)
>>> while run:
... x = f(x)
JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA that it _must_ write the result to this buffer.
If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:
>>> y = f(x)
>>> jax.device_get(x)
...
RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.
One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:
>>> @partial(jit, donate_argnums=0)
... def move(x):
... # Do something complex enough for JAX to just optimize it away.
... return tree_map(lambda x: x + x - x, x)
>>> def safe_eager_uniform(key, *a, **k):
... assert hasattr(key, 'device_buffer'), "random must run eagerly"
... key = move(key)
... return jax.random.uniform(key, *a, **k)
This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
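A small end-to-end sketch of the donation API described above (behavior is
backend-dependent; platforms without donation support may simply ignore the
hint):
```py
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, donate_argnums=0)
def step(x):
  return x * 2.0

x = jnp.ones((1024, 1024))
for _ in range(3):
  x = step(x)  # the donated input buffer may be reused for the output
```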
2020-05-31 23:00:16 +01:00
|
|
|
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
|
2020-03-28 14:15:46 -07:00
|
|
|
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
|
|
|
|
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
|
2019-07-27 15:46:14 -07:00
|
|
|
|
|
|
|
def post_process_call(self, call_primitive, out_tracers, params):
|
|
|
|
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
|
2020-06-23 09:39:45 -07:00
|
|
|
out, treedef = tree_flatten((primals, tangents))
|
2020-11-09 17:23:16 +00:00
|
|
|
tangents_nz = [type(t) is not Zero for t in tangents]
|
2019-07-27 15:46:14 -07:00
|
|
|
del primals, tangents
|
2020-08-30 12:38:14 +03:00
|
|
|
main = self.main
|
2018-11-17 18:03:33 -08:00
|
|
|
def todo(x):
|
2020-06-23 09:39:45 -07:00
|
|
|
primals, tangents = tree_unflatten(treedef, x)
|
2020-08-30 12:38:14 +03:00
|
|
|
trace = JVPTrace(main, core.cur_sublevel())
|
2019-07-27 15:46:14 -07:00
|
|
|
return map(partial(JVPTracer, trace), primals, tangents)
|
2020-11-09 17:23:16 +00:00
|
|
|
if call_primitive.map_primitive:
|
|
|
|
def out_axes_transform(out_axes):
|
|
|
|
return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz))
|
|
|
|
todo = (todo, out_axes_transform)
|
2019-07-27 15:46:14 -07:00
|
|
|
return out, todo
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
# The only difference between process_map and process_call is that
|
2020-11-09 17:23:16 +00:00
|
|
|
# the `in_axes` and `out_axes_thunk` params must be updated;
|
|
|
|
# that's handled in process_call.
|
2020-06-23 09:39:45 -07:00
|
|
|
process_map = process_call
|
2020-04-21 18:12:02 -07:00
|
|
|
post_process_map = post_process_call
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_jvp_call(self, _, __, f_jvp, tracers):
|
|
|
|
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
|
|
|
primals_in = map(core.full_lower, primals_in)
|
2020-05-28 13:20:56 +00:00
|
|
|
tangents_in = map(instantiate_zeros, tangents_in)
|
2020-10-08 15:36:05 +01:00
|
|
|
# Cast float0 to zeros with the primal dtype because custom jvp rules don't
|
2020-09-24 16:29:57 +01:00
|
|
|
# currently handle float0s
|
2020-10-16 00:21:04 -07:00
|
|
|
tangents_in = map(replace_float0s, primals_in, tangents_in)
|
2020-03-28 14:15:46 -07:00
|
|
|
outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
|
|
|
|
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
|
2020-10-16 00:21:04 -07:00
|
|
|
tangents_out = map(recast_to_float0, primals_out, tangents_out)
|
2020-03-28 14:15:46 -07:00
|
|
|
return map(partial(JVPTracer, self), primals_out, tangents_out)
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
def post_process_custom_jvp_call(self, out_tracers, params):
|
|
|
|
raise CustomJVPException()
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
|
|
|
|
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
2020-05-28 13:20:56 +00:00
|
|
|
tangents_in = map(instantiate_zeros, tangents_in)
|
2020-03-28 14:15:46 -07:00
|
|
|
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
|
|
|
|
out_tree, res_tree = out_trees()
|
|
|
|
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
|
|
|
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
|
|
|
tangents_out = custom_lin_p.bind(
|
|
|
|
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
|
|
|
|
avals_out=avals_out)
|
2020-10-16 00:21:04 -07:00
|
|
|
tangents_out = map(recast_to_float0, primals_out, tangents_out)
|
2020-03-28 14:15:46 -07:00
|
|
|
return map(partial(JVPTracer, self), primals_out, tangents_out)
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
def post_process_custom_vjp_call(self, out_tracers, params):
|
|
|
|
raise CustomVJPException()
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def join(self, xt, yt):
|
2020-05-27 13:57:47 +00:00
|
|
|
xz, yz = type(xt) is Zero, type(yt) is Zero
|
2019-07-27 15:46:14 -07:00
|
|
|
if xz == yz:
|
2018-11-17 18:03:33 -08:00
|
|
|
return xt, yt
|
2019-07-27 15:46:14 -07:00
|
|
|
elif yz and not xz:
|
|
|
|
return xt, zeros_like_jaxval(xt)
|
|
|
|
elif xz and not yz:
|
|
|
|
return zeros_like_jaxval(yt), yt
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
2018-11-21 13:20:44 -08:00
|
|
|
raise TypeError((xt, yt))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
class JVPTracer(Tracer):
|
2019-01-16 16:51:54 +00:00
|
|
|
__slots__ = ['primal', 'tangent']
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def __init__(self, trace, primal, tangent):
|
2019-05-10 15:52:12 -07:00
|
|
|
if not core.skip_checks:
|
|
|
|
_primal_tangent_shapes_match(primal, tangent)
|
2020-01-29 16:23:27 -05:00
|
|
|
self._trace = trace
|
2018-11-17 18:03:33 -08:00
|
|
|
self.primal = primal
|
|
|
|
self.tangent = tangent
|
|
|
|
|
|
|
|
@property
|
|
|
|
def aval(self):
|
|
|
|
# TODO(dougalm): add epsilon ball
|
|
|
|
return get_aval(self.primal)
|
|
|
|
|
|
|
|
def full_lower(self):
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(self.tangent) is Zero:
|
2018-11-17 18:03:33 -08:00
|
|
|
return core.full_lower(self.primal)
|
|
|
|
else:
|
|
|
|
return self
|
|
|
|
|
2019-05-10 15:52:12 -07:00
|
|
|
def _primal_tangent_shapes_match(primal, tangent):
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(tangent) is not Zero:
|
2020-10-07 11:41:22 -07:00
|
|
|
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
|
|
|
|
tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
|
2020-09-24 16:29:57 +01:00
|
|
|
assert primal_aval.shape == tangent_aval.shape, (primal_aval.shape, tangent_aval.shape)
|
|
|
|
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
|
|
|
|
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
|
2020-06-23 09:39:45 -07:00
|
|
|
|
|
|
|
call_param_updaters: Dict[core.Primitive, Callable] = {}
|
|
|
|
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
|
2019-05-10 15:52:12 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
# -------------------- Primitives --------------------
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
primitive_jvps : Dict[core.Primitive, Callable] = {}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
primitive_transposes: Dict[core.Primitive, Callable] = {}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def deflinear(primitive, transpose_rule):
|
|
|
|
primitive_jvps[primitive] = partial(linear_jvp, primitive)
|
|
|
|
primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
|
|
|
|
|
|
|
|
def linear_jvp(primitive, primals, tangents, **params):
|
|
|
|
val_out = primitive.bind(*primals, **params)
|
2020-05-27 13:57:47 +00:00
|
|
|
if all(type(tangent) is Zero for tangent in tangents):
|
|
|
|
return val_out, Zero.from_value(val_out)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
2020-05-28 13:20:56 +00:00
|
|
|
tangents = map(instantiate_zeros, tangents)
|
2018-11-17 18:03:33 -08:00
|
|
|
return val_out, primitive.bind(*tangents, **params)
|
|
|
|
|
|
|
|
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
|
2020-05-27 13:57:47 +00:00
|
|
|
return Zero if type(cotangent) is Zero else transpose_rule(cotangent, **kwargs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
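To illustrate the payoff (a hedged sketch, not the exact rule in this
commit): with an aval attached to the undefined primal, a reduce-sum
transpose can read the input shape off the aval instead of a params entry.
```py
from jax import lax

def reduce_sum_transpose(cotangent, operand, *, axes):
  # `operand` is an undefined primal; only its aval (shape/dtype) is known.
  input_shape = operand.aval.shape
  broadcast_dims = [d for d in range(len(input_shape)) if d not in axes]
  return [lax.broadcast_in_dim(cotangent, input_shape, broadcast_dims)]
```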
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def deflinear2(primitive, transpose_rule):
|
|
|
|
primitive_jvps[primitive] = partial(linear_jvp, primitive)
|
|
|
|
primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)
|
|
|
|
|
|
|
|
def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
|
2020-05-27 13:57:47 +00:00
|
|
|
return Zero if type(cotangent) is Zero else transpose_rule(cotangent, *args, **kwargs)
|
2020-03-13 07:13:29 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def defjvp(primitive, *jvprules):
|
|
|
|
assert isinstance(primitive, Primitive)
|
2020-05-27 13:57:47 +00:00
|
|
|
assert not primitive.multiple_results
|
2018-11-17 18:03:33 -08:00
|
|
|
primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)
|
|
|
|
|
|
|
|
|
|
|
|
def standard_jvp(jvprules, primitive, primals, tangents, **params):
|
|
|
|
val_out = primitive.bind(*primals, **params)
|
2019-02-20 12:36:18 -08:00
|
|
|
tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
|
2020-05-27 13:57:47 +00:00
|
|
|
if rule is not None and type(t) is not Zero]
|
|
|
|
return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defjvp2(primitive, *jvprules):
|
|
|
|
assert isinstance(primitive, Primitive)
|
2020-05-27 13:57:47 +00:00
|
|
|
assert not primitive.multiple_results
|
2018-11-17 18:03:33 -08:00
|
|
|
primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)
|
|
|
|
|
|
|
|
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
|
|
|
|
val_out = primitive.bind(*primals, **params)
|
|
|
|
tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents)
|
2020-05-27 13:57:47 +00:00
|
|
|
if rule is not None and type(t) is not Zero)
|
|
|
|
tangents_out = list(tangents_out)
|
|
|
|
return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def add_tangents(x, y):
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(x) is Zero:
|
2018-11-17 18:03:33 -08:00
|
|
|
return y
|
2020-05-27 13:57:47 +00:00
|
|
|
elif type(y) is Zero:
|
2018-11-17 18:03:33 -08:00
|
|
|
return x
|
|
|
|
else:
|
|
|
|
return add_jaxvals(x, y)
|
|
|
|
|
|
|
|
|
|
|
|
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
|
|
|
|
assert isinstance(prim, Primitive)
|
|
|
|
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
|
|
|
|
rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, bcast(g, x), **kwargs)
|
|
|
|
defjvp(prim, lhs_jvp, rhs_jvp)
|
|
|
|
primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
|
|
|
|
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
|
|
|
|
|
|
|
|
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
|
2020-03-13 07:13:29 -07:00
|
|
|
assert is_undefined_primal(x) ^ is_undefined_primal(y)
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(cotangent) is Zero:
|
|
|
|
return Zero
|
2020-03-13 07:13:29 -07:00
|
|
|
if is_undefined_primal(x):
|
2020-05-27 13:57:47 +00:00
|
|
|
out = lhs_rule(cotangent, y, **kwargs)
|
|
|
|
return Zero if out is Zero else (out, None)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
2020-05-27 13:57:47 +00:00
|
|
|
out = rhs_rule(cotangent, x, **kwargs)
|
|
|
|
return Zero if out is Zero else (None, out)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def defjvp_zero(primitive):
|
|
|
|
assert isinstance(primitive, Primitive)
|
|
|
|
primitive_jvps[primitive] = partial(zero_jvp, primitive)
|
|
|
|
|
|
|
|
def zero_jvp(primitive, primals, tangents, **params):
|
2020-05-27 13:57:47 +00:00
|
|
|
r = primitive.bind(*primals, **params)
|
|
|
|
return r, Zero.from_value(r)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-12-30 17:42:04 -08:00
|
|
|
deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
|
|
|
|
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-05-28 13:20:56 +00:00
|
|
|
def instantiate_zeros(tangent):
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(tangent) is Zero:
|
2020-09-24 16:29:57 +01:00
|
|
|
if isinstance(tangent.aval, Tracer):
|
|
|
|
return tangent.aval
|
2020-05-28 13:20:56 +00:00
|
|
|
return zeros_like_aval(tangent.aval)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
|
|
|
return tangent
|
|
|
|
|
2021-01-06 13:36:37 +02:00
|
|
|
# This function seems similar to instantiate_zeros, but it is sometimes used
|
2020-05-28 13:20:56 +00:00
|
|
|
# to instantiate zero abstract units with a different aval
|
2019-05-07 08:52:08 -07:00
|
|
|
def instantiate_zeros_aval(aval, tangent):
|
2020-05-27 13:57:47 +00:00
|
|
|
if type(tangent) is Zero:
|
2020-05-28 13:20:56 +00:00
|
|
|
assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
|
2019-05-07 08:52:08 -07:00
|
|
|
return zeros_like_aval(aval)
|
|
|
|
else:
|
|
|
|
return tangent
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
|
2019-07-27 15:46:14 -07:00
|
|
|
def traceable(num_primals, in_tree_def, *primals_and_tangents):
|
|
|
|
new_primals = primals_and_tangents[:num_primals]
|
|
|
|
new_tangents = primals_and_tangents[num_primals:]
|
|
|
|
new_tangents = tree_unflatten(in_tree_def, new_tangents)
|
2019-04-10 22:09:14 -07:00
|
|
|
primal_out, tangent_out = yield (new_primals, new_tangents), {}
|
2019-07-27 15:46:14 -07:00
|
|
|
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
|
|
|
|
yield out_flat, tree_def
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-26 07:56:48 -08:00
|
|
|
|
Simplify handling of non-linear equations in backward_pass and fix remat (#3162)
Previously, `backward_pass` had been generalized to be able to handle
non-linear computation in the body, but it could easily get confused
into doing unnecessary work only to throw it away later. Additionally, it
treated any call primitive embedded inside remat like remat itself,
which is obviously wrong.
This patch fixes both of those issues and simplifies a bunch of the code
at the same time. `backward_pass` now has an invariant that it only
deals with jaxprs containing linear equations alone, and becomes
a simple transposing interpreter again.
**Background on JVP vs linearization**
Ok, so why does this change actually fix the problem? It is important to
understand that JVP and linearization transforms are actually two
different things, even though we often identify them as one. Both take
in a function of type `a -> b`, but their ranges are different! JVP
returns a function of type `(a, T a) -> (b, T b)` while linearization
returns `a -> (b, T a --o T b)`. Note that the second type carries more
information, because we get a guarantee that (1) `b` does not depend on
`T a` and (2) the dependence of `T b` on `T a` is linear.
The reason why we usually treat them as equivalent, is that they can be
shown to be "isomorphic". If we take the output of linearization, we can
make it a JVP-like function using the following combinator:
```haskell
jvp f = \a ta -> let (b, lf) = linearize f a in (b, lf ta)
```
More importantly for JAX, which doesn't have a linearization interpreter,
if we assume (1) and (2), linearization can be recovered in terms of jvp
as well:
```haskell
linearize f = \a -> let fjvp = jvp f in
partial_eval fjvp (Known a) Unknown
```
That is, if we have a mathematically correct JVP, then linearization is
simply partial evaluation with all primal values marked as known, and
all tangents treated as yet unknown values.
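In user-facing terms the equivalence can be sanity-checked directly (a quick
sketch, not part of this patch):
```py
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x

y, f_lin = jax.linearize(f, 2.0)           # a -> (b, T a --o T b)
y_ref, t_ref = jax.jvp(f, (2.0,), (3.0,))  # (a, T a) -> (b, T b)

assert jnp.allclose(y, y_ref)
assert jnp.allclose(f_lin(3.0), t_ref)
```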
One important performance consideration is that for forward-mode AD we
really want to use the JVP formulation, which can interleave the computation
of primals and tangents, instead of sequencing them and increasing the memory
cost. On the other hand, transposition (necessary for VJPs!) can only be
applied to linear functions, and so it can't possibly work on the output
of JVP. It really can only be applied to the second output of the
linearization transform. Hence, we really care about both, but can we avoid
having two very similar implementations of (approximately) the same thing?
It seems that the answer is yes, because of the equivalence outlined above!
**If all this is so nice, then what's the problem?**
The problem is, of course, remat. Partial eval is able to thread the
known/unknown information correctly through regular call primitives, but
mind you, remat is no regular call primitive! Once we enter remat, we are
no longer interested in treating _anything_ like a known value. After
all, our goal here is to record an accurate trace of everything that has
happened in the body of a remat, including the primal (known!)
computation. This however presents a challenge for implementing
linearization in terms of JVP, because inside the body of remat we break
the assumption that known/unknown corresponds to the primal/tangent
distinction. Its body, instead of representing the second output of
linearization simply contains the traced JVP code now...
One way to fix it would be to implement a proper linearization pass that
would track the distinction between primal and tangent information while
still allowing us to stage out code for primals. @mattjj and I have even
started hacking together an implementation for that.
I've been trying to convince @mattjj that there is no other way to go
about it, but I couldn't really convince him that this is the case.
Then, once I wanted to write a semi-formal proof I could no longer even
convince myself! Turns out that there is an alternative solution!
What this patch does is stop requiring the output of the
`linearize` function (defined as JVP + partial eval, as discussed above)
to be a good linearization. It still is one if you don't use remat in your
code, but it breaks miserably once you do. However, as long as all
the complications are contained solely in the `call_jaxpr` embedded inside
a remat, we still have a chance to fix them! This is because the
transposition interpreter never reaches into those bodies directly, but
rather asks the call primitive to transpose itself.
Now, how do you transpose remat? We can't just reuse the code used for
regular call primitives (this is what happens now BTW), because unlike
for them, the `call_jaxpr` doesn't represent a linear function! But it's
not completely useless either --- it contains the traced JVP code. So,
how do we get from there to a linear function? Partial eval! And if you
think about it, it is exactly what we wanted --- we end up evaluating all
the primal code in the body once again, while only staging out the tangent
computation, to be passed into the transposing interpreter again.
Fin.
2020-05-27 20:22:40 +02:00
|
|
|
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
|
2020-02-05 11:08:21 +01:00
|
|
|
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
2020-02-05 15:38:25 +01:00
|
|
|
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
2019-07-27 15:46:14 -07:00
|
|
|
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
2020-06-23 09:39:45 -07:00
|
|
|
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
|
|
|
update_params = call_transpose_param_updaters.get(primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, map(is_undefined_primal, args),
|
|
|
|
[type(x) is not Zero for x in ct])
|
|
|
|
out_flat = primitive.bind(fun, *all_args, **new_params)
|
2019-07-27 15:46:14 -07:00
|
|
|
return tree_unflatten(out_tree(), out_flat)
|
|
|
|
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2020-05-27 20:22:40 +02:00
|
|
|
|
|
|
|
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
|
|
|
|
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
|
|
|
|
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
|
|
|
|
# (in this case via partial_eval) before we call into backward_pass again.
|
2020-09-18 10:07:13 -07:00
|
|
|
typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
|
2020-06-23 09:39:45 -07:00
|
|
|
unknowns = map(is_undefined_primal, primals_in)
|
2020-07-30 12:59:36 -07:00
|
|
|
if config.omnistaging_enabled:
|
|
|
|
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
|
|
|
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
|
|
|
|
else:
|
|
|
|
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
|
|
|
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
|
2020-09-15 08:06:46 -07:00
|
|
|
trace_type=None) # type: ignore
|
2020-05-27 20:22:40 +02:00
|
|
|
|
|
|
|
def do_transpose(primals_in, cotangents_in):
|
|
|
|
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
|
|
|
|
# should all work out, because we're only computing the primal part here.
|
|
|
|
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
|
|
|
|
# Now that we have a purely linear jaxpr, we can transpose it
|
|
|
|
cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in)
|
|
|
|
# backward_pass will return cotangents computed for all invars, but some of them
|
|
|
|
# are residuals appended by partial eval, so we need to skip those before we return.
|
|
|
|
return cotangents_out[:len(primals_in)]
|
|
|
|
|
|
|
|
flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
|
|
|
|
flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
|
|
|
|
flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
|
|
|
|
return tree_unflatten(out_tree(), flat_cotangents_out)
|
|
|
|
primitive_transposes[pe.remat_call_p] = remat_transpose
|
|
|
|
|
2020-12-02 14:13:05 +00:00
|
|
|
@lu.transformation_with_aux
|
|
|
|
def nonzero_outputs(*args, **kwargs):
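  # Wraps `fun` so that, alongside its results, it records (in the aux store) a list of
  # booleans marking which results are non-Zero; map_transpose below uses this to build
  # its out_axes_thunk.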
|
|
|
|
results = yield args, kwargs
|
|
|
|
yield results, [type(r) is not Zero for r in results]
|
|
|
|
|
2020-05-27 20:22:40 +02:00
|
|
|
|
|
|
|
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
2020-02-05 11:08:21 +01:00
|
|
|
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
2020-02-05 15:38:25 +01:00
|
|
|
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
2020-12-02 14:13:05 +00:00
|
|
|
fun, nz_arg_cts = nonzero_outputs(fun)
|
2019-07-27 15:46:14 -07:00
|
|
|
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)

def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes, output_statistic())

primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
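To make the hashing requirement concrete, here is a minimal sketch (hypothetical `HashableThunk` name; the real code uses the `as_hashable_function(closure=...)` decorator seen below) of a thunk that can serve as a cache key by delegating `hash`/`==` to a stable closure:
```py
class HashableThunk:
  def __init__(self, fun, closure):
    self.fun = fun          # the actual thunk to call
    self.closure = closure  # hashable data assumed to determine the thunk's output

  def __call__(self):
    return self.fun()

  def __hash__(self):
    return hash(self.closure)

  def __eq__(self, other):
    return type(other) is HashableThunk and self.closure == other.closure
```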
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assume there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
that argument is. I hope it's something better than just avoiding a single
jaxpr tracing pass.
2020-11-09 17:23:16 +00:00
|
|
|
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
|
|
|
in_axes, out_axes = params['in_axes'], params['out_axes']
|
|
|
|
new_in_axes = (*[axis for axis, x in zip(in_axes, args)
|
2020-11-05 11:54:05 +00:00
|
|
|
if not is_undefined_primal(x)],
|
2020-11-09 17:23:16 +00:00
|
|
|
*[axis for axis, x in zip(out_axes, ct)
|
|
|
|
if type(x) is not Zero])
|
|
|
|
# The interim strategy we use below (until avals-with-names) only works
|
|
|
|
# when all outputs are mapped.
|
|
|
|
assert all(out_axis is not None for out_axis in out_axes), out_axes
|
2020-12-02 14:13:05 +00:00
|
|
|
# NOTE: This assumes that the output cotangents being zero is a deterministic
|
|
|
|
# function of which input cotangents were zero.
|
|
|
|
@as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
|
2020-11-09 17:23:16 +00:00
|
|
|
def out_axes_thunk():
|
|
|
|
return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
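  # NOTE: `axis or 0` gives out_axis 0 even to cotangents of unmapped inputs (in_axis
  # None); those stacked cotangents are fanned back in by the `arg_ct.sum(0)` below.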
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure whether #1959 introduced the bug or merely refactored existing behavior in terms of `mapped_invars`; my guess is that, because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars`, things were working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we added inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, `call_primitive=True` or `map_primitive=True` implies constraints on which `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
|
2020-11-09 17:23:16 +00:00
|
|
|
in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
|
|
|
|
del new_params['out_axes']
|
2020-06-23 09:39:45 -07:00
|
|
|
update_params = call_transpose_param_updaters.get(primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, map(is_undefined_primal, args),
|
|
|
|
[type(x) is not Zero for x in ct])
|
2020-04-24 18:45:34 -07:00
|
|
|
out_flat = primitive.bind(fun, *all_args, **new_params)
|
2020-01-07 13:11:32 -08:00
|
|
|
arg_cts = tree_unflatten(out_tree(), out_flat)
|
|
|
|
|
|
|
|
# The freevars are being fanned out (not mapped). During transpose the
|
|
|
|
# dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
|
2020-11-05 11:54:05 +00:00
|
|
|
assert len(in_axes) == len(arg_cts)
|
|
|
|
def unmap_zero(zero, in_axis):
|
|
|
|
return (zero if in_axis is None else
|
|
|
|
Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
|
|
|
|
arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
|
2020-11-09 17:23:16 +00:00
|
|
|
arg_ct if in_axis is not None else
|
2020-11-05 11:54:05 +00:00
|
|
|
arg_ct.sum(0)
|
|
|
|
for arg_ct, in_axis in zip(arg_cts, in_axes))
|
|
|
|
return tuple(arg_cts)
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2019-05-10 08:58:05 -07:00
|
|
|
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
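  # Transforms a ClosedJaxpr into one that computes both primals and tangents: its inputs
  # are the original primal inputs followed by tangents for the positions where `nonzeros`
  # is True, and its outputs are the primal outputs followed by the nonzero output
  # tangents. Also returns the per-output nonzero-tangent flags.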
|
2019-07-27 15:46:14 -07:00
|
|
|
assert len(jaxpr.in_avals) == len(nonzeros)
|
2020-01-05 04:35:34 +01:00
|
|
|
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
2019-05-10 08:58:05 -07:00
|
|
|
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
|
2019-07-27 15:46:14 -07:00
|
|
|
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
|
|
|
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
2020-09-15 08:06:46 -07:00
|
|
|
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
2020-09-18 10:07:13 -07:00
|
|
|
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
|
2019-07-27 15:46:14 -07:00
|
|
|
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
|
|
|
|
num_primals = len(nonzeros)
|
|
|
|
primals = list(primals_and_nztangents[:num_primals])
|
|
|
|
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
|
2020-05-27 13:57:47 +00:00
|
|
|
tangents = [next(nonzero_tangents) if nz else Zero.from_value(p)
|
2020-06-08 09:47:32 -07:00
|
|
|
for p, nz in zip(primals, nonzeros)]
|
2019-07-27 15:46:14 -07:00
|
|
|
primals_out, tangents_out = yield (primals, tangents), {}
|
2020-05-27 13:57:47 +00:00
|
|
|
out_nonzeros = [type(t) is not Zero for t in tangents_out]
|
|
|
|
nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
|
2019-07-27 15:46:14 -07:00
|
|
|
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
|
|
|
|
|
2020-09-18 10:07:13 -07:00
|
|
|
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
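  # Reorders the jaxpr's invars/outvars from [all primals..., all tangents...] into groups
  # of primals interleaved with their corresponding tangents; the primals_*/tangents_*
  # arguments are group sizes (see _perm below).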
|
2019-07-27 15:46:14 -07:00
|
|
|
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
|
|
|
|
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
|
2020-01-07 13:11:32 -08:00
|
|
|
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
|
2019-07-27 15:46:14 -07:00
|
|
|
new_invars, new_outvars, jaxpr.jaxpr.eqns)
|
2020-09-18 10:07:13 -07:00
|
|
|
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
|
2019-07-27 15:46:14 -07:00
|
|
|
|
|
|
|
def _perm(primal_counts, tangent_counts, lst):
|
|
|
|
n = sum(primal_counts)
|
|
|
|
primals, tangents = lst[:n], lst[n:]
|
|
|
|
primal_groups = split_list(primals, primal_counts[:-1])
|
|
|
|
tangent_groups = split_list(tangents, tangent_counts[:-1])
|
|
|
|
return _interleave(primal_groups, tangent_groups)
|
|
|
|
|
|
|
|
def _interleave(xs, ys):
|
|
|
|
assert len(xs) == len(ys)
|
|
|
|
return [e for pair in zip(xs, ys) for l in pair for e in l]
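# Example (illustrative values): _perm([2, 1], [1, 2], [p0, p1, p2, t0, t1, t2]) splits the
# list into primal groups [[p0, p1], [p2]] and tangent groups [[t0], [t1, t2]], then
# interleaves them into [p0, p1, t0, p2, t1, t2].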
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
custom_lin_p = core.Primitive('custom_lin')
|
|
|
|
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
|
|
|
|
custom_lin_p.multiple_results = True
|
|
|
|
|
|
|
|
def _raise_custom_vjp_error_on_jvp(*_, **__):
|
|
|
|
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
|
|
|
|
"function.")
|
|
|
|
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
|
|
|
|
|
|
|
|
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
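  # `invals` holds the residuals (produced by the custom_vjp fwd rule) followed by the
  # undefined linear inputs; the user-supplied `bwd` rule maps residuals plus output
  # cotangents to input cotangents, and the residual slots themselves get no cotangent.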
|
|
|
|
res, _ = split_list(invals, [num_res])
|
|
|
|
cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
|
|
|
|
cts_in = bwd.call_wrapped(*res, *cts_out)
|
2020-10-16 00:21:04 -07:00
|
|
|
return [None] * num_res + list(cts_in)
|
2020-03-28 14:15:46 -07:00
|
|
|
primitive_transposes[custom_lin_p] = _custom_lin_transpose
|
|
|
|
|
|
|
|
|
2020-03-23 14:29:22 -07:00
|
|
|
# TODO(mattjj): delete everything below here (deprecated custom_transforms)
|
|
|
|
|
|
|
|
def defvjp_all(prim, custom_vjp):
|
|
|
|
# see https://github.com/google/jax/pull/636
|
|
|
|
name = prim.name
|
|
|
|
|
|
|
|
def fun_jvp(xs, ts, **params):
|
2020-05-28 13:20:56 +00:00
|
|
|
ts = map(instantiate_zeros, ts)
|
2020-03-23 14:29:22 -07:00
|
|
|
primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
|
|
|
|
primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
|
|
|
|
if prim.multiple_results:
|
|
|
|
return primals, tangents
|
|
|
|
else:
|
|
|
|
primal, = primals
|
|
|
|
tangent, = tangents
|
|
|
|
return primal, tangent
|
|
|
|
primitive_jvps[prim] = fun_jvp
|
|
|
|
|
|
|
|
fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
|
|
|
|
fun_jvp_p.multiple_results = True
|
|
|
|
def fun_jvp_partial_eval(trace, *tracers, **params):
|
|
|
|
primals, tangents = split_list(tracers, [len(tracers) // 2])
|
|
|
|
primals_out, vjp_py = custom_vjp(*primals, **params)
|
|
|
|
if not prim.multiple_results:
|
|
|
|
primals_out = [primals_out]
|
|
|
|
out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
|
2020-03-18 07:11:44 +01:00
|
|
|
ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
if config.omnistaging_enabled:
|
|
|
|
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
|
|
|
|
else:
|
2020-09-15 08:06:46 -07:00
|
|
|
with core.initial_style_staging(): # type: ignore
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
|
|
|
|
instantiate=True)
|
2020-03-23 14:29:22 -07:00
|
|
|
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
|
|
|
|
num_res=len(res), out_avals=out_avals)
|
|
|
|
return primals_out + tangents_out
|
|
|
|
pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
|
|
|
|
|
|
|
|
fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
|
|
|
|
fun_lin_p.multiple_results = True
|
|
|
|
fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
|
|
|
|
def fun_lin_transpose(cts, *args, **kwargs):
|
|
|
|
num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
|
|
|
|
res, _ = split_list(args, [num_res])
|
|
|
|
cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
|
|
|
|
outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
|
|
|
|
return [None] * num_res + outs
|
|
|
|
primitive_transposes[fun_lin_p] = fun_lin_transpose
|
|
|
|
|
|
|
|
def defvjp(prim, *vjps):
|
|
|
|
def vjpmaker(*primals):
|
|
|
|
ans = prim.bind(*primals)
|
|
|
|
vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
|
|
|
|
for x, vjp in zip(primals, vjps)]
|
|
|
|
return ans, vjpfun
|
|
|
|
defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
def defvjp2(prim, *vjps):
|
|
|
|
def vjpmaker(*primals):
|
|
|
|
ans = prim.bind(*primals)
|
|
|
|
vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
|
|
|
|
for x, vjp in zip(primals, vjps)]
|
|
|
|
return ans, vjpfun
|
|
|
|
defvjp_all(prim, vjpmaker)
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
class CustomJVPException(Exception):
|
|
|
|
def __init__(self):
|
|
|
|
# TODO(mattjj): track source provenance on AD tracers, improve error
|
|
|
|
msg = ("Detected differentiation of a custom_jvp function with respect to "
|
|
|
|
"a closed-over value. That isn't supported because the custom JVP "
|
|
|
|
"rule only specifies how to differentiate the custom_jvp function "
|
|
|
|
"with respect to explicit input parameters. Try passing the "
|
|
|
|
"closed-over value into the custom_jvp function as an argument, and "
|
|
|
|
"adapting the custom_jvp rule.")
|
|
|
|
super().__init__(msg)
|
|
|
|
|
|
|
|
class CustomVJPException(Exception):
|
|
|
|
def __init__(self):
|
|
|
|
# TODO(mattjj): track source provenance on AD tracers, improve error
|
|
|
|
msg = ("Detected differentiation of a custom_vjp function with respect to "
|
|
|
|
"a closed-over value. That isn't supported because the custom VJP "
|
|
|
|
"rule only specifies how to differentiate the custom_vjp function "
|
|
|
|
"with respect to explicit input parameters. Try passing the "
|
|
|
|
"closed-over value into the custom_vjp function as an argument, and "
|
|
|
|
"adapting the custom_vjp fwd and bwd rules.")
|
|
|
|
super().__init__(msg)
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
@config.register_omnistaging_disabler
|
|
|
|
def omnistaging_disabler() -> None:
|
2020-07-30 12:59:36 -07:00
|
|
|
global jvp_jaxpr
|
|
|
|
|
|
|
|
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
|
|
|
assert len(jaxpr.in_avals) == len(nonzeros)
|
|
|
|
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
|
|
|
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
|
|
|
|
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
|
|
|
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
2020-09-15 08:06:46 -07:00
|
|
|
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
2020-09-18 10:07:13 -07:00
|
|
|
jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
|
|
|
|
return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()
|