# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools as it
from typing import Any, Callable, Dict, Set, List
from . import partial_eval as pe
from .. import core as core
from ..core import Trace, Tracer, new_master, get_aval, call_p, Primitive, Literal
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
                       zeros_like_p, zero)
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, safe_map, safe_zip, partial, split_list, wrap_name
from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
zip = safe_zip
map = safe_map

def identity(x): return x

def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate), aux
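# Illustrative usage sketch (not a public API): the WrappedFun returned by
# `jvp` maps a (primals, tangents) pair of tuples to an (out_primals,
# out_tangents) pair, where the wrapped fun returns a flat sequence of outputs:
#   f = lu.wrap_init(lambda x: [x * 2.])
#   (y,), (y_dot,) = jvp(f).call_wrapped((3.,), (1.,))  # y == 6.0, y_dot == 2.0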
@lu.transformation
def jvpfun(instantiate, primals, tangents):
  with new_master(JVPTrace) as master:
    out_primals, out_tangents = yield (master, primals, tangents), {}
    del master
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(x, t) if inst else t for x, t, inst
                  in zip(out_primals, out_tangents, instantiate)]
  yield out_primals, out_tangents
@lu.transformation
def jvp_subtrace(master, primals, tangents):
  trace = JVPTrace(master, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  in_tracers = [JVPTracer(trace, x, t) if t is not zero else x
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.primal, out_tracer.tangent)
                for out_tracer in out_tracers])
@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  trace = JVPTrace(master, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  aux_tracers = map(trace.full_raise, aux)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals, _ = unzip2((t.primal, t.tangent) for t in aux_tracers)
  aux_primals = map(core.full_lower, aux_primals)
  yield (out_primals, out_tangents), aux_primals
def linearize(traceable, *primals, **kwargs):
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
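# NOTE: `linearize` evaluates the primal computation eagerly while staging the
# tangent computation out to a jaxpr: primals are marked known and tangents
# unknown for partial evaluation, so the returned jaxpr contains only the
# linear (tangent) part, closing over `consts`.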
def vjp(traceable, primals, has_aux=False):
  if not has_aux:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  else:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  def vjp_(*cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    dummy_primals_and_cts = (core.unit,) * len(cts) + cts
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, consts, dummy_args, dummy_primals_and_cts)
    arg_cts = arg_cts[len(primals):]
    return map(instantiate_zeros, primals, arg_cts)

  if not has_aux:
    return out_primals, vjp_
  else:
    return out_primals, vjp_, aux
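# Illustrative sketch (via the public jax.vjp / jax.numpy wrappers, not this
# module's internal `vjp`): pulling a cotangent back through a linearized fun.
#   def f(x):
#     for i in range(1000):
#       x = x * 4
#     return x
#   x = np.ones(4)
#   vjp(f, x)[1](x)  # runs backward_pass over the linearized jaxpr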
def ignore_consts(ct, pval):
  aval, const = pval
  if isinstance(aval, core.AbstractValue):
    return ct
  elif aval is None:
    return core.unit
  else:
    raise TypeError(aval)
def unpair_pval(pval):
  aval, const = pval
  const_1, const_2 = const
  if aval is None:
    return (None, const_1), (None, const_2)
  else:
    aval_1, aval_2 = aval
    return (aval_1, const_1), (aval_2, const_2)

# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  if all(ct is zero for ct in cotangents_in):
    return [zero] * len(jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal and ct is not zero:
      ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
      if not core.skip_checks:
        ct_aval = core.get_aval(ct_env[v])
        assert v.aval == core.lattice_join(v.aval, ct_aval)

  def read_cotangent(v):
    return ct_env.get(v, zero)

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primals_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)
  def is_linear(var):
    if type(var) is Literal:
      return False
    else:
      return var not in primal_env

  linear_eqns = []
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    if not (prim.call_primitive or prim.map_primitive):
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      else:
        in_vals = map(read_primal, eqn.invars)
        ans = prim.bind(*in_vals, **eqn.params)
        if prim.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      if any(not is_linear(v) for v in eqn.invars):
        # FIXME: Some invars correspond to tangents
        ans = _eval_subjaxpr_primals(prim, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in linear_eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set
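  # NOTE: Without this eviction, every cotangent written below would stay alive
  # in ct_env until the whole transposition finished, raising peak memory from
  # roughly the live-set width of the transposed program to the total size of
  # all forward intermediates (see #2719).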
  ct_env: Dict[Any, Any] = {}
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn, to_drop in zip(linear_eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      cts_out = get_primitive_transpose(eqn.primitive)(
          params, call_jaxpr, invals, cts_in)
    else:
      cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
    cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(write_cotangent, eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
  assert not jaxpr.constvars
  all_args, in_tree_def = tree_flatten((in_vals,))
  fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  assert prim.map_primitive ^ prim.call_primitive
  if prim.map_primitive:
    new_mapped_invars = [m for m, x in zip(params['mapped_invars'], in_vals)
                         if not is_undefined_primal(x)]
    new_params = dict(params, mapped_invars=tuple(new_mapped_invars))
    out_flat = prim.bind(fun, *all_args, **new_params)
  else:
    out_flat = prim.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
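# NOTE: UndefinedPrimal is a zero-leaf pytree node (see its registration
# below), so undefined inputs vanish from all_args during flattening, with
# their avals riding along in the treedef. For map primitives the matching
# mapped_invars entries must therefore be pruned so the mask stays aligned
# with the surviving arguments (see #2828).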
def _eval_primals(jaxpr, args):
  primal_env = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  def is_linear(var):
    if type(var) is Literal:
      return False
    else:
      return var not in primal_env

  write_primal(core.unitvar, core.unit)
  assert not jaxpr.constvars
  map(write_primal, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if not (eqn.primitive.call_primitive or eqn.primitive.map_primitive):
      if not any(is_linear(v) for v in eqn.invars):
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if any(not is_linear(v) for v in eqn.invars):
        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)
  return map(read_primal, jaxpr.outvars)
class UndefinedPrimal:
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)

def is_undefined_primal(x):
  return type(x) is UndefinedPrimal

register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))
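# UndefinedPrimal marks a primal input that is unavailable during
# transposition, i.e. one we are transposing with respect to. Registering it
# as a zero-leaf pytree node lets it pass through flatten/unflatten with its
# aval stored in the treedef rather than as a runtime value.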
def get_primitive_transpose(p):
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Transpose rule (for reverse-mode differentiation) for '{}' "
        "not implemented".format(p)) from err
class JVPTrace(Trace):

  def pure(self, val):
    return JVPTracer(self, val, zero)

  def lift(self, val):
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    try:
      jvp = primitive_jvps[primitive]
    except KeyError as err:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive)) from err
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)
  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                    len(primals), in_tree_def)
    name = params.get('name', f.__name__)
    params = dict(params, name=wrap_name(name, 'jvp'))
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
  def post_process_call(self, call_primitive, out_tracers, params):
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out = primals + tangents
    del primals, tangents
    master = self.master
    def todo(x):
      n = len(x) // 2
      primals, tangents = x[:n], x[n:]
      trace = JVPTrace(master, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    return out, todo
  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    # only differs from process_call in that it must update mapped_invars
    # TODO de-duplicate code
    assert map_primitive.multiple_results
    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                    len(primals), in_tree_def)
    new_name = wrap_name(params.get('name', f.__name__), 'jvp')
    new_mapped_invars = (*params['mapped_invars'],
                         *[m for m, t in zip(params['mapped_invars'], tangents)
                           if t is not zero])
    new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars)
    result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  post_process_map = post_process_call
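  # NOTE on process_map above: bind is passed both the primals and the nonzero
  # tangents as positional inputs, so mapped_invars is grown with one entry per
  # nonzero tangent, copying the mapping of the corresponding primal (#2828).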
  def process_custom_jvp_call(self, _, __, f_jvp, tracers):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    primals_in = map(core.full_lower, primals_in)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    return map(partial(JVPTracer, self), primals_out, tangents_out)
  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
    out_tree, res_tree = out_trees()
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        avals_out=avals_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)
  def join(self, xt, yt):
    xz, yz = xt is zero, yt is zero
    if xz == yz:
      return xt, yt
    elif yz and not xz:
      return xt, zeros_like_jaxval(xt)
    elif xz and not yz:
      return zeros_like_jaxval(yt), yt
    else:
      raise TypeError((xt, yt))

class JVPTracer(Tracer):
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    if self.tangent is zero:
      return core.full_lower(self.primal)
    else:
      return self
def _primal_tangent_shapes_match(primal, tangent):
  if tangent is not zero:
    primal_aval = raise_to_shaped(get_aval(primal))
    tangent_aval = raise_to_shaped(get_aval(tangent))
    assert primal_aval == tangent_aval

# -------------------- Primitives --------------------

primitive_jvps: Dict[core.Primitive, Callable] = {}
primitive_transposes: Dict[core.Primitive, Callable] = {}

def deflinear(primitive, transpose_rule):
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
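# A deflinear transpose rule receives the output cotangent and returns an
# iterable of input cotangents; e.g. the registration
# `deflinear(add_jaxvals_p, lambda t: (t, t))` below transposes addition by
# fanning the cotangent out to both addends.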
def linear_jvp(primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  if all(tangent is zero for tangent in tangents):
    return val_out, zero
  else:
    tangents = map(instantiate_zeros, primals, tangents)
    return val_out, primitive.bind(*tangents, **params)

def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  return zero if cotangent is zero else transpose_rule(cotangent, **kwargs)

def deflinear2(primitive, transpose_rule):
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)

def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  return zero if cotangent is zero else transpose_rule(cotangent, *args, **kwargs)
def defjvp(primitive, *jvprules):
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)

def standard_jvp(jvprules, primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
                  if rule is not None and t is not zero]
  return val_out, functools.reduce(add_tangents, tangents_out, zero)
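# Each entry of jvprules handles one positional argument: a rule is called as
# rule(tangent, *primals, **params) and returns that argument's contribution
# to the output tangent; the contributions are then summed. A hypothetical
# sketch for a two-argument multiply primitive:
#   defjvp(mul_p,
#          lambda g, x, y: mul(g, y),   # tangent contribution from x
#          lambda g, x, y: mul(x, g))   # tangent contribution from y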
def defjvp2(primitive, *jvprules):
assert isinstance(primitive, Primitive)
primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
val_out = primitive.bind(*primals, **params)
tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents)
if rule is not None and t is not zero)
return val_out, functools.reduce(add_tangents, tangents_out, zero)
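
# `defjvp2` rules additionally receive the primal output, useful when the
# derivative is cheapest to express in terms of the answer, as in jax.lax's
# tanh rule (sketch; `_one` is a lax-internal ones-like helper):
#
#   defjvp2(tanh_p, lambda g, ans, x: mul(g, sub(_one(x), mul(ans, ans))))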
def add_tangents(x, y):
if x is zero:
return y
elif y is zero:
return x
else:
return add_jaxvals(x, y)
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
assert isinstance(prim, Primitive)
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, bcast(g, x), **kwargs)
defjvp(prim, lhs_jvp, rhs_jvp)
primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
assert is_undefined_primal(x) ^ is_undefined_primal(y)
if is_undefined_primal(x):
out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
return out, None
else:
out = zero if cotangent is zero else rhs_rule(cotangent, x, **kwargs)
return None, out
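
# Example: bilinear primitives register one transpose rule per operand, and
# exactly one operand is an undefined primal at transpose time (hence the
# assert above). A sketch in the style of jax.lax's dot_general registration
# (the two helpers are lax-internal):
#
#   defbilinear(dot_general_p, _dot_general_transpose_lhs,
#               _dot_general_transpose_rhs)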
def defjvp_zero(primitive):
assert isinstance(primitive, Primitive)
primitive_jvps[primitive] = partial(zero_jvp, primitive)
def zero_jvp(primitive, primals, tangents, **params):
return primitive.bind(*primals, **params), zero
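
# Example: primitives with identically-zero tangents, like comparisons, use
# defjvp_zero, as jax.lax does for its equality primitive:
#
#   defjvp_zero(eq_p)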
deflinear(zeros_like_p, lambda t: [zero])
deflinear(core.identity_p, lambda t: (t,))
deflinear(add_jaxvals_p, lambda t: (t, t))
def instantiate_zeros(example, tangent):
if tangent is zero:
return zeros_like_jaxval(example)
else:
return tangent
def instantiate_zeros_aval(aval, tangent):
if tangent is zero:
return zeros_like_aval(aval)
else:
return tangent
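
# The symbolic `zero` sentinel avoids materializing arrays of zeros; these
# helpers produce a concrete zero only when a consumer needs one. Sketch:
#
#   instantiate_zeros(x, zero)  # -> zeros_like_jaxval(x)
#   instantiate_zeros(x, t)     # -> t, when t is not the symbolic zero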
@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
new_primals = primals_and_tangents[:num_primals]
new_tangents = primals_and_tangents[num_primals:]
new_tangents = tree_unflatten(in_tree_def, new_tangents)
primal_out, tangent_out = yield (new_primals, new_tangents), {}
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
yield out_flat, tree_def
def call_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
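
# Rebinding the same call primitive around backward_pass means that the
# transpose of, say, a jit-ed linear function is itself staged out under the
# call. Hedged sketch of the observable effect (assuming `import jax`):
#
#   y, f_vjp = jax.vjp(jax.jit(lambda x: 2. * x), 3.)
#   f_vjp(1.)  # the cotangent computation runs under a transposed call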
def map_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
if not is_undefined_primal(x)],
*[True for x in ct if x is not zero])
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
mapped_invars=new_mapped_invars)
out_flat = primitive.bind(fun, *all_args, **new_params)
arg_cts = tree_unflatten(out_tree(), out_flat)
mapped_invars = params['mapped_invars'] # True for each mapped invar
  # Unmapped invars are fanned out (broadcast), not mapped; the transpose of
  # fan-out is fan-in-sum, so we sum the cotangents of unmapped invars over
  # the mapped axis.
assert len(mapped_invars) == len(arg_cts)
arg_cts = (arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
for arg_ct, arg_mapped in zip(arg_cts, mapped_invars))
return arg_cts
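
# Worked sketch of the fan-in-sum above: if an unmapped invar x is broadcast
# to n mapped instances, its cotangent is the sum of the n per-instance
# cotangents, i.e. ct_x = ct.sum(0) for a stacked cotangent ct of shape
# (n, ...).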
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
return jaxpr_out, out_nonzeros()
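
# Hedged sketch of the calling convention: given a TypedJaxpr with inputs
# (x1, x2) and nonzeros=[True, False], the returned jaxpr consumes
# (x1, x2, t1) and yields all primal outputs followed by the nonzero tangent
# outputs reported in out_nonzeros:
#
#   jvp_tjaxpr, out_nz = jvp_jaxpr(tjaxpr, [True, False], instantiate=False)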
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
num_primals = len(nonzeros)
primals = list(primals_and_nztangents[:num_primals])
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
tangents = [next(nonzero_tangents) if nz else zero for nz in nonzeros]
primals_out, tangents_out = yield (primals, tangents), {}
out_nonzeros = [t is not zero for t in tangents_out]
nonzero_tangents_out = [t for t in tangents_out if t is not zero]
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns)
new_in_avals = _perm(primals_in, tangents_in, jaxpr.in_avals)
new_out_avals = _perm(primals_out, tangents_out, jaxpr.out_avals)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, jaxpr.literals, new_in_avals,
new_out_avals)
return new_typed_jaxpr
def _perm(primal_counts, tangent_counts, lst):
n = sum(primal_counts)
primals, tangents = lst[:n], lst[n:]
primal_groups = split_list(primals, primal_counts[:-1])
tangent_groups = split_list(tangents, tangent_counts[:-1])
return _interleave(primal_groups, tangent_groups)
def _interleave(xs, ys):
assert len(xs) == len(ys)
return [e for pair in zip(xs, ys) for l in pair for e in l]
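
# Worked example: with counts [1, 1] and [1, 1],
#   _perm([1, 1], [1, 1], [p1, p2, t1, t2]) == [p1, t1, p2, t2]
# i.e. "all primals then all tangents" becomes interleaved per-group pairs.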
custom_lin_p = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
custom_lin_p.multiple_results = True
def _raise_custom_vjp_error_on_jvp(*_, **__):
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
"function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
res, _ = split_list(invals, [num_res])
cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
cts_in = bwd.call_wrapped(*res, *cts_out)
cts_in_flat, _ = tree_flatten(cts_in) # already checked tree structure
return [None] * num_res + cts_in_flat
primitive_transposes[custom_lin_p] = _custom_lin_transpose
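
# custom_lin is the linearized half of a custom_vjp call: its transpose simply
# applies the user-supplied bwd to the residuals and the instantiated output
# cotangents, while its JVP is an error because custom_vjp functions do not
# support forward mode.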
# TODO(mattjj): delete everything below here (deprecated custom_transforms)
def defvjp_all(prim, custom_vjp):
# see https://github.com/google/jax/pull/636
name = prim.name
def fun_jvp(xs, ts, **params):
ts = map(instantiate_zeros, xs, ts)
primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
if prim.multiple_results:
return primals, tangents
else:
primal, = primals
tangent, = tangents
return primal, tangent
primitive_jvps[prim] = fun_jvp
fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
fun_jvp_p.multiple_results = True
def fun_jvp_partial_eval(trace, *tracers, **params):
primals, tangents = split_list(tracers, [len(tracers) // 2])
primals_out, vjp_py = custom_vjp(*primals, **params)
if not prim.multiple_results:
primals_out = [primals_out]
out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
with core.initial_style_staging():
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
instantiate=True)
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
num_res=len(res), out_avals=out_avals)
return primals_out + tangents_out
pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
fun_lin_p.multiple_results = True
fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
def fun_lin_transpose(cts, *args, **kwargs):
num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
res, _ = split_list(args, [num_res])
cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
return [None] * num_res + outs
primitive_transposes[fun_lin_p] = fun_lin_transpose
def defvjp(prim, *vjps):
def vjpmaker(*primals):
ans = prim.bind(*primals)
vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
for x, vjp in zip(primals, vjps)]
return ans, vjpfun
defvjp_all(prim, vjpmaker)
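
# Hedged usage sketch for the deprecated defvjp: one VJP rule per argument,
# each mapping (ct, *primals) to that argument's cotangent. For a hypothetical
# primitive my_mul_p computing x * y:
#
#   defvjp(my_mul_p, lambda ct, x, y: ct * y, lambda ct, x, y: ct * x)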
def defvjp2(prim, *vjps):
def vjpmaker(*primals):
ans = prim.bind(*primals)
vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
for x, vjp in zip(primals, vjps)]
return ans, vjpfun
defvjp_all(prim, vjpmaker)
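
# defvjp2 rules additionally receive the primal output `ans`, mirroring the
# defjvp / defjvp2 distinction above.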