2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
2020-01-08 13:17:55 -05:00
|
|
|
import functools
|
2019-02-23 20:34:14 -08:00
|
|
|
import itertools as it
|
2020-04-20 12:24:05 +02:00
|
|
|
from typing import Any, Callable, Dict, Set, List
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from . import partial_eval as pe
|
|
|
|
from .. import core as core
|
2019-07-26 16:48:17 -04:00
|
|
|
from ..core import Trace, Tracer, new_master, get_aval, call_p, Primitive, Literal
|
2019-05-07 08:52:08 -07:00
|
|
|
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
2020-01-05 04:35:34 +01:00
|
|
|
zeros_like_p, zero)
|
2019-04-23 17:47:28 -07:00
|
|
|
from ..abstract_arrays import raise_to_shaped
|
2020-01-26 23:27:56 -08:00
|
|
|
from ..util import unzip2, safe_map, safe_zip, partial, split_list, wrap_name
|
2020-01-05 04:35:34 +01:00
|
|
|
from ..tree_util import register_pytree_node
|
|
|
|
from .. import linear_util as lu
|
2019-07-27 15:46:14 -07:00
|
|
|
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
2019-07-26 23:17:21 -04:00
|
|
|
from ..tree_util import tree_flatten, tree_unflatten
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Module-wide aliases: within this file ``zip`` and ``map`` refer to the
# helpers imported from ``..util`` rather than to the builtins (presumably
# length-checking / list-returning variants -- see ..util to confirm).
zip = safe_zip
map = safe_map


def identity(x):
  """Return ``x`` unchanged."""
  return x
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
  """Wrap ``fun`` for forward-mode (JVP) evaluation.

  The result is ``fun`` composed with ``jvp_subtrace`` (or
  ``jvp_subtrace_aux``) and ``jvpfun``. ``instantiate`` is forwarded to
  ``jvpfun`` to choose which output tangents are materialized.

  Returns:
    The wrapped function, or ``(wrapped_function, aux)`` when ``has_aux`` is
    true. Here ``aux`` is the auxiliary-output thunk produced by
    ``jvp_subtrace_aux``.
  """
  if has_aux:
    subtraced, aux_thunk = jvp_subtrace_aux(fun)
    return jvpfun(subtraced, instantiate), aux_thunk
  return jvpfun(jvp_subtrace(fun), instantiate)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-18 08:26:23 -05:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
def jvpfun(instantiate, primals, tangents):
  """Run the wrapped function under a fresh ``JVPTrace`` master.

  The inner transformation receives ``(master, primals, tangents)``. Once it
  yields ``(out_primals, out_tangents)``, the selected output tangents are
  passed through ``instantiate_zeros``.

  Args:
    instantiate: either a single bool applied to every output, or a sequence
      with one bool per output tangent.
    primals: primal inputs.
    tangents: tangent inputs.
  """
  with new_master(JVPTrace) as master:
    out_primals, out_tangents = yield (master, primals, tangents), {}
    # Release our reference before the context exits -- presumably so that
    # new_master's cleanup can check the master was not leaked (TODO confirm
    # against core.new_master).
    del master
  per_output = ([instantiate] * len(out_tangents) if type(instantiate) is bool
                else instantiate)
  out_tangents = [instantiate_zeros(primal, tangent) if flag else tangent
                  for primal, tangent, flag
                  in zip(out_primals, out_tangents, per_output)]
  yield out_primals, out_tangents
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
def jvp_subtrace(master, primals, tangents):
  """Wrap inputs in ``JVPTracer``s at the current sublevel and split outputs.

  A primal whose tangent is the symbolic ``zero`` is passed through unwrapped.
  Every output is raised into the new trace. The outputs are then returned as
  ``(primals, tangents)`` sequences built by ``unzip2``.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Any tracer flowing in must belong to an enclosing (lower-level) trace.
  for val in (*primals, *tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  in_tracers = [primal if tangent is zero else JVPTracer(trace, primal, tangent)
                for primal, tangent in zip(primals, tangents)]
  ans = yield in_tracers, {}
  raised = map(trace.full_raise, ans)
  yield unzip2([(tracer.primal, tracer.tangent) for tracer in raised])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  """Variant of ``jvp_subtrace`` for functions that also return aux data.

  Every primal/tangent pair is wrapped in a ``JVPTracer``, with no special
  case for symbolic zeros. The wrapped function yields ``(ans, aux)``. The
  main outputs are split into primal and tangent sequences. For the aux
  outputs only the primal part is kept, lowered with ``core.full_lower``,
  and emitted as the transformation's auxiliary result.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Any tracer flowing in must belong to an enclosing (lower-level) trace.
  for val in (*primals, *tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  in_tracers = map(partial(JVPTracer, trace), primals, tangents)
  ans, aux = yield in_tracers, {}
  raised_ans = map(trace.full_raise, ans)
  raised_aux = map(trace.full_raise, aux)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in raised_ans)
  # Aux tangents are discarded; only lowered primal values are reported.
  aux_primals = [core.full_lower(t.primal) for t in raised_aux]
  yield (out_primals, out_tangents), aux_primals
|
2019-03-07 14:08:02 -08:00
|
|
|
|
|
|
|
def linearize(traceable, *primals, has_aux=False):
  """Linearize ``traceable`` at ``primals`` using JVP plus partial evaluation.

  The JVP of ``traceable`` is traced with known primal inputs and unknown
  tangent inputs. The tangent avals come from ``get_aval(p).at_least_vspace()``.
  As a result, primal outputs are computed concretely and the tangent
  computation is captured in a jaxpr.

  Args:
    traceable: the ``lu.WrappedFun`` to linearize (it is passed to ``jvp``).
    *primals: primal input values.
    has_aux: keyword-only. If true, ``traceable`` also produces auxiliary
      output (see ``jvp``), which is appended to the result.

  Returns:
    ``(out_primals_consts, out_tangents_pvals, jaxpr, consts)``. When
    ``has_aux`` is true, the auxiliary output is appended as a fifth element.

  Raises:
    TypeError: for any keyword argument other than ``has_aux``.
  """
  # ``has_aux`` is keyword-only so that a misspelled or unsupported keyword
  # raises TypeError instead of being silently dropped.
  if not has_aux:
    jvp_fun = jvp(traceable)
  else:
    jvp_fun, aux = jvp(traceable, has_aux=True)

  # Flattened inputs are (primals, tangents): primals known, tangents unknown.
  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvp_fun_flat, out_tree = flatten_fun(jvp_fun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvp_fun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  # Primal outputs are expected to be fully known after partial evaluation.
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    # The aux thunk is only populated once tracing has run.
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-03-07 14:08:02 -08:00
|
|
|
def vjp(traceable, primals, has_aux=False):
  """Linearize ``traceable`` at ``primals`` and return a VJP closure.

  Returns:
    ``(out_primals, vjp_)``, or ``(out_primals, vjp_, aux)`` when ``has_aux``
    is true. ``vjp_(*cts)`` takes one cotangent per linearized output tangent.
    It returns one cotangent per entry of ``primals``, each passed through
    ``instantiate_zeros``.
  """
  if has_aux:
    out_primals, out_pvals, jaxpr, consts, aux = linearize(
        traceable, *primals, has_aux=True)
  else:
    out_primals, out_pvals, jaxpr, consts = linearize(traceable, *primals)

  def vjp_(*cts):
    cts = tuple(map(ignore_consts, cts, out_pvals))
    # The jaxpr's outputs start with the (known) primal outputs; they get
    # core.unit placeholders ahead of the real cotangents.
    primal_and_tangent_cts = (core.unit,) * len(cts) + cts
    # Every jaxpr input is passed in as an undefined primal.
    undefined_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    all_cts = backward_pass(jaxpr, consts, undefined_args,
                            primal_and_tangent_cts)
    # Keep only the cotangents that follow the leading len(primals) inputs.
    return map(instantiate_zeros, primals, all_cts[len(primals):])

  return (out_primals, vjp_, aux) if has_aux else (out_primals, vjp_)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-03 22:24:46 -05:00
|
|
|
def ignore_consts(ct, pval):
  """Pass ``ct`` through for unknown outputs; use ``core.unit`` for known ones.

  ``pval`` is unpacked as an ``(aval, const)`` pair.

  Returns:
    ``core.unit`` when the aval is ``None``, otherwise ``ct`` when the aval
    is a ``core.AbstractValue``.

  Raises:
    TypeError: if the aval is neither ``None`` nor an ``AbstractValue``.
  """
  aval, _ = pval
  # Presumably a None aval marks a known (constant) partial value, whose
  # cotangent is irrelevant.
  if aval is None:
    return core.unit
  if isinstance(aval, core.AbstractValue):
    return ct
  raise TypeError(aval)
|
|
|
|
|
|
|
|
def unpair_pval(pval):
  """Split a partial value whose const (and aval, if any) are pairs.

  ``pval`` is ``(aval, (const_1, const_2))``. ``aval`` is either ``None`` or
  a pair ``(aval_1, aval_2)``.

  Returns:
    ``((aval_1, const_1), (aval_2, const_2))``, with both avals ``None`` when
    ``aval`` is ``None``.
  """
  aval, (first_const, second_const) = pval
  first_aval, second_aval = (None, None) if aval is None else aval
  return (first_aval, first_const), (second_aval, second_const)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-17 11:20:54 +00:00
|
|
|
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  """Propagate cotangents from ``jaxpr``'s outputs back to its inputs.

  The pass runs in two sweeps:
    1. Forward: equations whose inputs all have known primal values are
       evaluated eagerly. Equations with any "linear" input (one with no known
       primal) are queued. Call/map primitives may be both queued and
       partially evaluated via ``_eval_subjaxpr_primals``.
    2. Reverse: queued equations are transposed last-to-first with their
       registered transpose rules, accumulating cotangents in ``ct_env``.

  Args:
    jaxpr: the jaxpr to transpose.
    consts: values bound to ``jaxpr.constvars``.
    primals_in: one value per ``jaxpr.invars``. Entries for which
      ``is_undefined_primal`` is true are not recorded, so those vars are
      treated as linear.
    cotangents_in: one cotangent per ``jaxpr.outvars``. ``zero`` is the
      symbolic zero; ``None`` entries are ignored.

  Returns:
    A list with one cotangent per ``jaxpr.invars``. Entries are ``zero``
    where nothing was accumulated.
  """
  # All output cotangents are symbolic zeros, so every input cotangent is too.
  if all(ct is zero for ct in cotangents_in):
    return [zero] * len(jaxpr.invars)

  def write_cotangent(v, ct):
    # Accumulate ``ct`` into ct_env[v], skipping None, symbolic zeros and
    # Literal vars.
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal and ct is not zero:
      ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
      if not core.skip_checks:
        # Sanity check: the accumulated cotangent's aval must be compatible
        # with (i.e. join to) the var's declared aval.
        ct_aval = core.get_aval(ct_env[v])
        assert v.aval == core.lattice_join(v.aval, ct_aval)

  def read_cotangent(v):
    # Vars with no accumulated cotangent read as the symbolic zero.
    return ct_env.get(v, zero)

  def read_primal(v):
    # Literals carry their value inline; vars with no known primal read as
    # an UndefinedPrimal tagged with the var's aval.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Only defined primals are recorded. Undefined vars therefore stay absent
    # from primal_env, which is exactly what is_linear tests.
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  # forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  def is_linear(var):
    # A var is linear when no primal value is known for it; literals never are.
    if type(var) is Literal:
      return False
    else:
      return var not in primal_env

  # Forward sweep: evaluate what we can, queue the rest for transposition.
  linear_eqns = []
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    if not (prim.call_primitive or prim.map_primitive):
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      else:
        in_vals = map(read_primal, eqn.invars)
        ans = prim.bind(*in_vals, **eqn.params)
        if prim.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
      # A call/map eqn with mixed inputs is both queued for transposition and
      # partially evaluated for whatever primal outputs it can produce.
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      if any(not is_linear(v) for v in eqn.invars):
        # FIXME: Some invars correspond to tangents
        ans = _eval_subjaxpr_primals(prim, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible.
  # drop_cts[i] holds the outvars of linear_eqns[i] that are neither jaxpr
  # inputs nor outvars of an earlier linear eqn. Their cotangents are popped
  # right after eqn i is transposed in the reverse sweep below. Cotangents of
  # jaxpr.invars are never dropped, because they form the result.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in linear_eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set

  # Reverse sweep: seed output cotangents, then transpose last-to-first.
  ct_env: Dict[Any, Any] = {}
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn, to_drop in zip(linear_eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
      # Call/map transpose rules take (params, call_jaxpr, invals, cts_in).
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      cts_out = get_primitive_transpose(eqn.primitive)(
          params, call_jaxpr, invals, cts_in)
    else:
      cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
    # A bare symbolic zero stands for "zero cotangent for every input".
    cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(write_cotangent, eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-02-03 20:58:56 +01:00
|
|
|
def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
  """Evaluate the known-primal part of a call/map sub-jaxpr by binding ``prim``.

  Args:
    prim: a call primitive or a map primitive (exactly one of the two flags is
      set, as asserted below).
    jaxpr: the sub-jaxpr to evaluate; it must have no constvars.
    in_vals: one value per sub-jaxpr input; may contain ``UndefinedPrimal``
      placeholders.
    params: the primitive's params, excluding the call jaxpr itself.

  Returns:
    The sub-jaxpr outputs as produced by ``_eval_primals``, unflattened.
  """
  assert not jaxpr.constvars
  # UndefinedPrimal is registered as a pytree node with no children, so
  # flattening removes those placeholders from the list of bound arguments.
  flat_args, in_tree = tree_flatten((in_vals,))
  fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
  assert prim.map_primitive ^ prim.call_primitive
  bind_params = params
  if prim.map_primitive:
    # Undefined-primal args were pruned from flat_args above, so prune the
    # matching mapped_invars entries to keep the two aligned.
    kept_mapped = tuple(m for m, x in zip(params['mapped_invars'], in_vals)
                        if not is_undefined_primal(x))
    bind_params = dict(params, mapped_invars=kept_mapped)
  out_flat = prim.bind(fun, *flat_args, **bind_params)
  return tree_unflatten(out_tree(), out_flat)
|
|
|
|
|
2020-02-03 20:58:56 +01:00
|
|
|
def _eval_primals(jaxpr, args):
  """Evaluate every equation of ``jaxpr`` whose inputs are all known primals.

  ``args`` supplies one value per ``jaxpr.invars``. Entries that are
  ``UndefinedPrimal`` are never stored, so the corresponding variables count as
  linear (unknown).

  - Ordinary equations are evaluated only when none of their inputs is linear.
  - Call/map equations are evaluated through ``_eval_subjaxpr_primals`` when at
    least one input is known.

  Returns one value per ``jaxpr.outvars``. A variable that was never computed
  reads back as ``UndefinedPrimal(var.aval)``.
  """
  env = {}

  def read(var):
    if type(var) is Literal:
      return var.val
    return env.get(var, UndefinedPrimal(var.aval))

  def write(var, val):
    # Undefined primals are deliberately not recorded; see is_linear.
    if not is_undefined_primal(val):
      env[var] = val

  def is_linear(var):
    # Literals are always known; other vars are linear unless a value was stored.
    return type(var) is not Literal and var not in env

  write(core.unitvar, core.unit)
  assert not jaxpr.constvars
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    if prim.call_primitive or prim.map_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
      if any(not is_linear(v) for v in eqn.invars):
        outs = _eval_subjaxpr_primals(prim, call_jaxpr,
                                      map(read, eqn.invars), params)
        map(write, eqn.outvars, outs)
    elif all(not is_linear(v) for v in eqn.invars):
      ans = prim.bind(*map(read, eqn.invars), **eqn.params)
      if prim.multiple_results:
        map(write, eqn.outvars, ans)
      else:
        write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition record it in their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
class UndefinedPrimal:
  """Placeholder for a primal value that is not known.

  Only the abstract value (``aval``) of the missing primal is kept.
  """
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)


def is_undefined_primal(x):
  """Return True iff ``x``'s exact type is ``UndefinedPrimal`` (subclasses excluded)."""
  return type(x) is UndefinedPrimal
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
# Register UndefinedPrimal as a pytree node with no children. Flattening
# contributes no leaves and keeps only its aval as auxiliary data, which
# unflattening uses to rebuild the placeholder.
register_pytree_node(
    UndefinedPrimal,
    lambda undef: ((), undef.aval),
    lambda aval, _: UndefinedPrimal(aval))
|
2019-05-07 08:52:08 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def get_primitive_transpose(p):
  """Return the transpose rule registered for primitive ``p``.

  Raises:
    NotImplementedError: if ``primitive_transposes`` has no entry for ``p``
      (chained from the underlying ``KeyError``).
  """
  try:
    rule = primitive_transposes[p]
  except KeyError as err:
    msg = ("Transpose rule (for reverse-mode differentiation) for '{}' "
           "not implemented".format(p))
    raise NotImplementedError(msg) from err
  return rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
class JVPTrace(Trace):
  """Trace that carries (primal, tangent) pairs for forward-mode AD.

  Values on this trace are ``JVPTracer`` objects. A tangent may be the
  symbolic ``zero``.
  """

  def pure(self, val):
    return JVPTracer(self, val, zero)

  def lift(self, val):
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    """Apply ``primitive``'s registered JVP rule to the tracers' primals and tangents.

    Raises:
      NotImplementedError: if ``primitive_jvps`` has no rule for ``primitive``.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    try:
      jvp = primitive_jvps[primitive]
    except KeyError as err:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive)) from err
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def _process_call_or_map(self, primitive, f, tracers, params, is_map):
    """Shared JVP logic for call and map primitives.

    Binds ``primitive`` on the JVP-transformed ``f``. The arguments are the
    primals followed by the flattened tangents (``nonzero_tangents``). The
    bound ``name`` param is wrapped with ``'jvp'``.

    When ``is_map`` is true, ``mapped_invars`` is extended with one entry per
    non-``zero`` tangent, copied from the corresponding primal's entry, so it
    stays aligned with the grown argument list.
    """
    assert primitive.multiple_results
    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                    len(primals), in_tree_def)
    new_params = dict(params,
                      name=wrap_name(params.get('name', f.__name__), 'jvp'))
    if is_map:
      mapped_invars = params['mapped_invars']
      new_params['mapped_invars'] = (
          *mapped_invars,
          *[m for m, t in zip(mapped_invars, tangents) if t is not zero])
    result = primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    return self._process_call_or_map(call_primitive, f, tracers, params,
                                     is_map=False)

  def post_process_call(self, call_primitive, out_tracers, params):
    """Flatten ``out_tracers`` into a primals-then-tangents sequence.

    Returns that sequence and ``todo``. ``todo`` splits such a sequence in half
    and rebuilds ``JVPTracer``s on a new ``JVPTrace`` at the current sublevel.
    """
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out = primals + tangents
    del primals, tangents
    # Bind only what ``todo`` needs, so the closure does not reference ``self``.
    master = self.master
    def todo(x):
      n = len(x) // 2
      primals, tangents = x[:n], x[n:]
      trace = JVPTrace(master, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    return out, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    # Same as process_call, except that mapped_invars must also be grown to
    # cover the added tangent arguments.
    return self._process_call_or_map(map_primitive, f, tracers, params,
                                     is_map=True)

  post_process_map = post_process_call

  def process_custom_jvp_call(self, _, __, f_jvp, tracers):
    """Run the user JVP ``f_jvp`` on lowered primals and instantiated tangents.

    ``f_jvp`` returns primals followed by tangents; the two halves are paired
    back into ``JVPTracer``s.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    primals_in = map(core.full_lower, primals_in)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    """Run ``fwd`` for residuals and primal outputs.

    Tangent outputs come from ``custom_lin_p``, which is given the residuals,
    the instantiated input tangents, ``bwd`` and the shaped output avals.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
    _, res_tree = out_trees()
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        avals_out=avals_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def join(self, xt, yt):
    """Make two tangents compatible.

    If exactly one of them is the symbolic ``zero``, replace it with concrete
    zeros built from the other one. Otherwise return both unchanged.
    """
    xz, yz = xt is zero, yt is zero
    if xz == yz:
      return xt, yt
    elif xz:
      return zeros_like_jaxval(yt), yt
    else:
      return xt, zeros_like_jaxval(xt)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
class JVPTracer(Tracer):
  """Tracer pairing a ``primal`` value with its ``tangent`` (possibly ``zero``)."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    # Consistency check between primal and tangent avals; bypassed when
    # core.skip_checks is set.
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal, self.tangent = primal, tangent

  @property
  def aval(self):
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    """Lower to the primal when the tangent is symbolically zero; else keep self."""
    return core.full_lower(self.primal) if self.tangent is zero else self
|
|
|
|
|
2019-05-10 15:52:12 -07:00
|
|
|
def _primal_tangent_shapes_match(primal, tangent):
  """Assert that a non-symbolic-zero tangent has the same aval as its primal.

  Both values are passed through `raise_to_shaped(get_aval(...))` before being
  compared. A symbolic-zero tangent is always accepted.

  Raises:
    AssertionError: if the two avals differ. The message contains both avals.
  """
  if tangent is zero:
    return
  primal_aval = raise_to_shaped(get_aval(primal))
  tangent_aval = raise_to_shaped(get_aval(tangent))
  # Include both avals so a mismatch is diagnosable from the error alone.
  assert primal_aval == tangent_aval, (primal_aval, tangent_aval)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# -------------------- Primitives --------------------
|
|
|
|
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
# Registry of JVP rules, keyed by primitive.
primitive_jvps: Dict[core.Primitive, Callable] = dict()

# Registry of transpose rules, keyed by primitive.
primitive_transposes: Dict[core.Primitive, Callable] = dict()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def deflinear(primitive, transpose_rule):
  """Register JVP and transpose rules for a primitive that is linear in its inputs.

  The JVP rule is `linear_jvp` bound to `primitive`. The transpose rule is
  `linear_transpose`, which calls `transpose_rule(cotangent, **params)`
  without the primal arguments.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
|
|
|
|
|
|
|
|
def linear_jvp(primitive, primals, tangents, **params):
  """JVP of a linear primitive: apply the same primitive to the tangents.

  Returns:
    `(primal_out, tangent_out)`. `tangent_out` is the symbolic `zero` when
    every input tangent is symbolic zero. Otherwise any symbolic zeros are
    instantiated with `instantiate_zeros` before `primitive` is bound on them.
  """
  val_out = primitive.bind(*primals, **params)
  if not any(t is not zero for t in tangents):
    return val_out, zero
  dense_tangents = [instantiate_zeros(p, t) for p, t in zip(primals, tangents)]
  return val_out, primitive.bind(*dense_tangents, **params)
|
|
|
|
|
|
|
|
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Transpose rule for `deflinear` primitives.

  The primal operands in `args` are ignored. A symbolic-zero cotangent maps
  to `zero`; otherwise the result is `transpose_rule(cotangent, **kwargs)`.
  """
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, **kwargs)
|
|
|
|
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
  { lambda ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition carry it in their parameters, so that we'd produce
jaxprs like this one:
  { lambda ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }
That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def deflinear2(primitive, transpose_rule):
  """Like `deflinear`, but the transpose rule also receives the primal operands.

  The transpose rule is `linear_transpose2`, which calls
  `transpose_rule(cotangent, *args, **params)`.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose2, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
|
|
|
|
|
|
|
|
def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  """Transpose rule for `deflinear2` primitives.

  A symbolic-zero cotangent maps to `zero`. Otherwise the result is
  `transpose_rule(cotangent, *args, **kwargs)`, with the primal operands
  forwarded.
  """
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, *args, **kwargs)
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def defjvp(primitive, *jvprules):
  """Register per-operand JVP rules for `primitive` (evaluated by `standard_jvp`).

  `jvprules[i]` is called as `rule(tangent_i, *primals, **params)`. A `None`
  entry contributes no tangent term for operand `i`.
  """
  assert isinstance(primitive, Primitive)
  jvp_rule = partial(standard_jvp, jvprules, primitive)
  primitive_jvps[primitive] = jvp_rule
|
|
|
|
|
|
|
|
|
|
|
|
def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """Evaluate per-operand JVP rules and sum their contributions.

  A rule is skipped when it is `None` or its operand's tangent is the symbolic
  `zero`. All remaining terms are computed first and then summed left to right
  with `add_tangents`, starting from `zero`.

  Returns:
    `(primal_out, tangent_out)`.
  """
  val_out = primitive.bind(*primals, **params)
  active = [(rule, t) for rule, t in zip(jvprules, tangents)
            if rule is not None and t is not zero]
  terms = [rule(t, *primals, **params) for rule, t in active]
  tangent_out = zero
  for term in terms:
    tangent_out = add_tangents(tangent_out, term)
  return val_out, tangent_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but each rule also receives the primal output.

  Rules are called as `rule(tangent_i, primal_out, *primals, **params)` by
  `standard_jvp2`.
  """
  assert isinstance(primitive, Primitive)
  jvp_rule = partial(standard_jvp2, jvprules, primitive)
  primitive_jvps[primitive] = jvp_rule
|
|
|
|
|
|
|
|
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Evaluate per-operand JVP rules (which also see the output) and sum them.

  Rules that are `None`, or whose operand tangent is the symbolic `zero`, are
  skipped. Each remaining term is computed and immediately accumulated with
  `add_tangents`, starting from `zero`.

  Returns:
    `(primal_out, tangent_out)`.
  """
  val_out = primitive.bind(*primals, **params)
  tangent_out = zero
  for rule, t in zip(jvprules, tangents):
    if rule is None or t is zero:
      continue
    tangent_out = add_tangents(tangent_out, rule(t, val_out, *primals, **params))
  return val_out, tangent_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def add_tangents(x, y):
  """Add two tangents, treating the symbolic `zero` as the additive identity.

  Only when both operands are non-zero is `add_jaxvals` actually invoked.
  """
  if x is not zero and y is not zero:
    return add_jaxvals(x, y)
  return y if x is zero else x
|
|
|
|
|
|
|
|
|
|
|
|
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a primitive bilinear in `(x, y)`.

  The JVP w.r.t. `x` binds `prim(bcast(g, y), y)`, and the JVP w.r.t. `y`
  binds `prim(x, bcast(g, x))`. Transposition goes through
  `bilinear_transpose`, which dispatches to `lhs_rule` or `rhs_rule`.
  """
  assert isinstance(prim, Primitive)

  def lhs_jvp(g, x, y, **kwargs):
    return prim.bind(bcast(g, y), y, **kwargs)

  def rhs_jvp(g, x, y, **kwargs):
    return prim.bind(x, bcast(g, x), **kwargs)

  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
|
|
|
|
# `defbilinear_broadcasting` with a no-op broadcast: the tangent `g` is passed
# to the primitive unchanged.
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
|
|
|
|
|
|
|
|
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose a bilinear primitive w.r.t. its single undefined operand.

  Exactly one of `x`, `y` must be an undefined primal. The cotangent for that
  operand is `lhs_rule(cotangent, y, ...)` or `rhs_rule(cotangent, x, ...)`,
  or `zero` if `cotangent` is the symbolic zero.

  Returns:
    A pair of cotangents, with `None` in the slot of the defined operand.
  """
  x_undefined = is_undefined_primal(x)
  assert x_undefined ^ is_undefined_primal(y)
  if x_undefined:
    ct_x = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
    return ct_x, None
  ct_y = zero if cotangent is zero else rhs_rule(cotangent, x, **kwargs)
  return None, ct_y
|
|
|
|
|
|
|
|
|
|
|
|
def defjvp_zero(primitive):
  """Register `zero_jvp` as the JVP of `primitive` (output tangent is always `zero`)."""
  assert isinstance(primitive, Primitive)
  jvp_rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = jvp_rule
|
|
|
|
|
|
|
|
def zero_jvp(primitive, primals, tangents, **params):
  """Bind `primitive` on the primals and pair the result with a symbolic `zero`.

  `tangents` is ignored.
  """
  primal_out = primitive.bind(*primals, **params)
  return primal_out, zero
|
|
|
|
|
|
|
|
|
2018-12-11 09:18:38 -08:00
|
|
|
# Transpose rules for built-in linear primitives; each maps an output
# cotangent `t` to one cotangent per input.
#   zeros_like: its input cotangent is the symbolic zero.
deflinear(zeros_like_p, lambda t: [zero])
#   identity: the cotangent passes through unchanged.
deflinear(core.identity_p, lambda t: (t,))
#   add_jaxvals: both summands receive the same cotangent.
deflinear(add_jaxvals_p, lambda t: (t, t))
|
|
|
|
|
|
|
|
def instantiate_zeros(example, tangent):
  """Return `tangent`, or `zeros_like_jaxval(example)` if it is the symbolic `zero`."""
  return zeros_like_jaxval(example) if tangent is zero else tangent
|
|
|
|
|
2019-05-07 08:52:08 -07:00
|
|
|
def instantiate_zeros_aval(aval, tangent):
  """Return `tangent`, or `zeros_like_aval(aval)` if it is the symbolic `zero`."""
  return zeros_like_aval(aval) if tangent is zero else tangent
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  """Adapt a flat `(primals..., tangent leaves...)` argument list for a JVP function.

  The first `num_primals` arguments are passed through as primals. The rest
  are tangent leaves, rebuilt with `in_tree_def`. The wrapped function's
  `(primal_out, tangent_out)` result is flattened, and its tree definition is
  returned as the auxiliary output.
  """
  primals = primals_and_tangents[:num_primals]
  tangents = tree_unflatten(in_tree_def, primals_and_tangents[num_primals:])
  primal_out, tangent_out = yield (primals, tangents), {}
  out_flat, out_tree_def = tree_flatten((primal_out, tangent_out))
  yield out_flat, out_tree_def
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-26 07:56:48 -08:00
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
def call_transpose(primitive, params, call_jaxpr, args, ct):
  """Transpose a call primitive by binding it on `backward_pass` over `call_jaxpr`.

  The bound call is given the flattened `((), args, ct)` tuple, with empty
  consts. Its name is wrapped with `'transpose'`.

  Returns:
    The backward pass's outputs, unflattened to their original structure.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  backward = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  backward, out_tree = flatten_fun_nokwargs(backward, in_tree_def)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(backward, *all_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)
|
|
|
|
# The plain call primitive and the remat call primitive share the generic
# call transpose, each re-binding its own primitive on the backward pass.
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
def map_transpose(primitive, params, call_jaxpr, args, ct):
  """Transpose a map primitive by binding it on `backward_pass` over `call_jaxpr`.

  `mapped_invars` is rebuilt for the transposed call. Cotangents of unmapped
  (fanned-out) inputs are summed over the leading axis.

  Returns:
    A list with one cotangent per entry of `args`. Previously this was a
    one-shot generator, inconsistent with `call_transpose`, which returns a
    list.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Undefined-primal args and symbolic-zero cotangents are pruned from the
  # inputs of the transposed call, so `mapped_invars` is pruned to match.
  # Flags of defined args are kept, and every non-zero cotangent is mapped.
  new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
                         if not is_undefined_primal(x)],
                       *[True for x in ct if x is not zero])
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    mapped_invars=new_mapped_invars)
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  mapped_invars = params['mapped_invars']  # True for each mapped invar
  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(mapped_invars) == len(arg_cts)
  # Materialize eagerly so callers may take len() or iterate more than once.
  return [arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
          for arg_ct, arg_mapped in zip(arg_cts, mapped_invars)]
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2019-05-10 08:58:05 -07:00
|
|
|
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Trace the JVP of a TypedJaxpr into a new TypedJaxpr.

  Args:
    jaxpr: the TypedJaxpr to differentiate.
    nonzeros: one bool per input of `jaxpr`. Only inputs marked True get a
      tangent input in the result; the others are treated as symbolic zero.
    instantiate: forwarded to `jvp`.

  Returns:
    `(jvp_jaxpr, out_nonzeros)`. The inputs of `jvp_jaxpr` are all primals
    followed by the tangents for inputs marked True in `nonzeros`.
    `out_nonzeros` is the auxiliary output mask from `f_jvp_traceable`.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  fun_jvp, out_nonzeros = f_jvp_traceable(jvp(fun, instantiate=instantiate), nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = [*jaxpr.in_avals, *tangent_avals]
  pvals_in = [pe.PartialVal.unknown(aval) for aval in avals_in]
  traced, pvals_out, literals_out = pe.trace_to_jaxpr(fun_jvp, pvals_in,
                                                      instantiate=True)
  avals_out, _ = unzip2(pvals_out)
  typed = core.TypedJaxpr(traced, literals_out, avals_in, avals_out)
  return typed, out_nonzeros()
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Adapt a JVP function to take and return only non-zero tangents.

  Inputs are `len(nonzeros)` primals followed by one tangent for each True
  entry in `nonzeros`; the missing tangents become the symbolic `zero`.
  Outputs are the primal outputs followed by the non-symbolic-zero output
  tangents. The auxiliary output is the per-output mask of which tangents
  were non-zero.
  """
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  provided = iter(primals_and_nztangents[num_primals:])
  tangents = [next(provided) if nz else zero for nz in nonzeros]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [t is not zero for t in tangents_out]
  kept_tangents = [t for t in tangents_out if t is not zero]
  yield [*primals_out, *kept_tangents], out_nonzeros
|
|
|
|
|
2020-01-07 13:11:32 -08:00
|
|
|
def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Reorder a jaxpr's binders from (all primals, all tangents) into interleaved groups.

  `primals_in`/`tangents_in` and `primals_out`/`tangents_out` are group sizes
  as interpreted by `_perm`. Invars and in_avals are permuted with the input
  sizes; outvars and out_avals with the output sizes. Consts, literals and
  equations are kept.
  """
  inner = jaxpr.jaxpr
  new_jaxpr = core.Jaxpr(inner.constvars,
                         _perm(primals_in, tangents_in, inner.invars),
                         _perm(primals_out, tangents_out, inner.outvars),
                         inner.eqns)
  return core.TypedJaxpr(new_jaxpr, jaxpr.literals,
                         _perm(primals_in, tangents_in, jaxpr.in_avals),
                         _perm(primals_out, tangents_out, jaxpr.out_avals))
|
|
|
|
|
|
|
|
def _perm(primal_counts, tangent_counts, lst):
  """Interleave primal groups with tangent groups.

  `lst` holds `sum(primal_counts)` primals followed by tangents. The primals
  are split at the sizes `primal_counts[:-1]` (the last group takes the rest),
  and the tangents likewise at `tangent_counts[:-1]`. The groups are then
  interleaved pairwise by `_interleave`.
  """
  num_primals = sum(primal_counts)
  primal_groups = split_list(lst[:num_primals], primal_counts[:-1])
  tangent_groups = split_list(lst[num_primals:], tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)
|
|
|
|
|
|
|
|
def _interleave(xs, ys):
|
|
|
|
assert len(xs) == len(ys)
|
|
|
|
return [e for pair in zip(xs, ys) for l in pair for e in l]
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
# Primitive bound by `process_custom_vjp_call` for the tangents of a custom_vjp
# function. Its operands are `num_res` residuals followed by the input
# tangents. Abstract evaluation simply returns the `avals_out` parameter.
custom_lin_p = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
custom_lin_p.multiple_results = True
|
|
|
|
|
|
|
|
def _raise_custom_vjp_error_on_jvp(*_, **__):
|
|
|
|
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
|
|
|
|
"function.")
|
|
|
|
# Concretely evaluating custom_lin_p (rather than transposing it) raises a
# TypeError explaining that jvp cannot be applied to a custom_vjp function.
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
|
|
|
|
|
|
|
|
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
  """Transpose rule for `custom_lin_p`: call the user's `bwd` function.

  The first `num_res` operands are residuals and receive `None` as their
  cotangents. Symbolic-zero output cotangents are instantiated from
  `avals_out`, and `bwd` is called on the residuals followed by those
  cotangents.

  Returns:
    `[None] * num_res` followed by the flattened input cotangents from `bwd`.
  """
  res, _ = split_list(invals, [num_res])
  dense_cts = [instantiate_zeros_aval(aval, ct) for aval, ct in zip(avals_out, cts_out)]
  cts_in = bwd.call_wrapped(*res, *dense_cts)
  cts_in_flat, _ = tree_flatten(cts_in)  # already checked tree structure
  return [None] * num_res + cts_in_flat
|
|
|
|
# Reverse mode through a custom_vjp function transposes custom_lin_p via `bwd`.
primitive_transposes[custom_lin_p] = _custom_lin_transpose
|
|
|
|
|
|
|
|
|
2020-03-23 14:29:22 -07:00
|
|
|
# TODO(mattjj): delete everything below here (deprecated custom_transforms)
|
|
|
|
|
|
|
|
def defvjp_all(prim, custom_vjp):
  """Define differentiation rules for `prim` from a VJP maker (deprecated API).

  `custom_vjp(*primals, **params)` must return `(primals_out, vjp_py)`. This
  function registers three things:
    * a JVP rule for `prim` that binds a new `{name}_jvp` primitive on the
      primals and instantiated tangents;
    * a partial-eval rule for `{name}_jvp` that calls `custom_vjp`, traces
      `vjp_py` to a jaxpr, and binds a `{name}_lin` primitive on the
      residuals and tangents;
    * a transpose rule for `{name}_lin` that evaluates that traced jaxpr on
      the cotangents.
  """
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    # Symbolic-zero tangents are made concrete before binding fun_jvp_p.
    ts = map(instantiate_zeros, xs, ts)
    # fun_jvp_p is assigned later in defvjp_all; the closure looks it up
    # when fun_jvp runs, not when it is defined.
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    # Outputs are the primals followed by an equal number of tangents.
    primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    # Operands are the primals followed by an equal number of tangents.
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    # Trace vjp_py on fully unknown cotangents. The third result (`res`) is
    # later passed as the jaxpr's consts in fun_lin_transpose.
    ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
    with core.initial_style_staging():
      jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
                                        instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    # NOTE(review): assumes primals_out is a list here; a tuple returned by
    # custom_vjp when prim.multiple_results is set would fail to concatenate
    # with the list of tangents -- confirm against callers.
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  # Abstract evaluation returns the output avals recorded at bind time.
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    # Leading operands are residuals; they receive None cotangents below.
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose
|
|
|
|
|
|
|
|
def defvjp(prim, *vjps):
  """Define reverse-mode rules for `prim` from one VJP function per operand.

  `vjps[i]` is called as `vjp(ct, *primals)`. A falsy entry produces
  `zeros_like_jaxval` of the corresponding primal instead. The rules are
  registered through `defvjp_all`.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
              for x, vjp in zip(primals, vjps)]

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
def defvjp2(prim, *vjps):
  """Like `defvjp`, but each VJP function also receives the primal output.

  `vjps[i]` is called as `vjp(ct, ans, *primals)`. A falsy entry produces
  `zeros_like_jaxval` of the corresponding primal instead.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
              for x, vjp in zip(primals, vjps)]

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|