# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import itertools as it
from typing import Any, Callable, Dict

from . import partial_eval as pe
from .. import core as core
from ..core import Trace, Tracer, new_master, get_aval, call_p, Primitive, Literal
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
                       zeros_like_aval, zeros_like_p, zero)
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, safe_map, safe_zip, partial, split_list, wrap_name
from ..tree_util import register_pytree_node, tree_flatten, tree_unflatten
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs

zip = safe_zip
map = safe_map

def identity(x): return x

def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
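  """Transform a wrapped `fun` into one mapping (primals, tangents) to
  (out_primals, out_tangents). The public `jax.jvp` drives this
  transformation; illustratively:

    jax.jvp(lambda x: x * x, (3.,), (1.,))   # -> (9.0, 6.0)
  """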
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate), aux

@lu.transformation
def jvpfun(instantiate, primals, tangents):
  with new_master(JVPTrace) as master:
    out_primals, out_tangents = yield (master, primals, tangents), {}
    del master
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(x, t) if inst else t for x, t, inst
                  in zip(out_primals, out_tangents, instantiate)]
  yield out_primals, out_tangents

@lu.transformation
def jvp_subtrace(master, primals, tangents):
  trace = JVPTrace(master, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  in_tracers = [JVPTracer(trace, x, t) if t is not zero else x
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.primal, out_tracer.tangent)
                for out_tracer in out_tracers])

@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  trace = JVPTrace(master, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  aux_tracers = map(trace.full_raise, aux)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals, _ = unzip2((t.primal, t.tangent) for t in aux_tracers)
  aux_primals = map(core.full_lower, aux_primals)
  yield (out_primals, out_tangents), aux_primals

def linearize(traceable, *primals, **kwargs):
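  """Linearize `traceable` at `primals`.

  Returns (out_primal_consts, out_tangent_pvals, jaxpr, consts[, aux]): the
  primal outputs are computed now, while the tangent computation is staged
  out as `jaxpr` for later evaluation or transposition.
  """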
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal((None, p)) for p in primals)
              + tuple(pe.PartialVal((get_aval(p).at_least_vspace(), core.unit))
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
  aval_primals, const_primals = unzip2(pval_primals)
  assert all(aval_primal is None for aval_primal in aval_primals)
  if not has_aux:
    return const_primals, pval_tangents, jaxpr, consts
  else:
    return const_primals, pval_tangents, jaxpr, consts, aux()

def vjp(traceable, primals, has_aux=False):
  if not has_aux:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  else:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals,
                                                       has_aux=True)
  def vjp_(*cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    dummy_primals_and_cts = (core.unit,) * len(cts) + cts
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, consts, dummy_args, dummy_primals_and_cts)
    arg_cts = arg_cts[len(primals):]
    return map(instantiate_zeros, primals, arg_cts)

  if not has_aux:
    return out_primals, vjp_
  else:
    return out_primals, vjp_, aux

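# The public `jax.vjp` is assembled from this transformation; illustratively:
#   y, f_vjp = jax.vjp(lambda x: x * x, 3.)   # y == 9.0
#   x_bar, = f_vjp(1.0)                       # x_bar == 6.0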
def ignore_consts(ct, pval):
  aval, const = pval
  if isinstance(aval, core.AbstractValue):
    return ct
  elif aval is None:
    return core.unit
  else:
    raise TypeError(aval)

def unpair_pval(pval):
  aval, const = pval
  const_1, const_2 = const
  if aval is None:
    return (None, const_1), (None, const_2)
  else:
    aval_1, aval_2 = aval
    return (aval_1, const_1), (aval_2, const_2)

def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
  if all(ct is zero for ct in cotangents_in):
    return [zero] * len(jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, zero)

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  map(write_primal, jaxpr.invars, args)

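  # Forward sweep: evaluate and record every equation whose inputs are all
  # known, and collect the linear equations to be transposed in the reverse
  # sweep below. A variable counts as linear if it never enters primal_env.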
  def is_linear(var):
    if type(var) is Literal:
      return False
    else:
      return var not in primal_env

  linear_eqns = []
  for eqn in jaxpr.eqns:
    if not eqn.primitive.call_primitive:
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      else:
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      if any(not is_linear(v) for v in eqn.invars):
        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)

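  # Reverse sweep: propagate cotangents from outputs back to inputs, visiting
  # the linear equations in reverse order and applying each primitive's
  # registered transpose rule.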
  ct_env: Dict[Any, Any] = {}
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in linear_eqns[::-1]:
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    if eqn.primitive.call_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      cts_out = get_primitive_transpose(eqn.primitive)(
          params, call_jaxpr, invals, cts_in)
    else:
      cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                       **eqn.params)
    cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    map(write_cotangent, eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out

def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
  assert not jaxpr.constvars
  all_args, in_tree_def = tree_flatten((in_vals,))
  fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  out_flat = prim.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)

def _eval_primals(jaxpr, args):
  primal_env = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  def is_linear(var):
    if type(var) is Literal:
      return False
    else:
      return var not in primal_env

  write_primal(core.unitvar, core.unit)
  assert not jaxpr.constvars
  map(write_primal, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if not eqn.primitive.call_primitive:
      if not any(is_linear(v) for v in eqn.invars):
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if any(not is_linear(v) for v in eqn.invars):
        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)
  return map(read_primal, jaxpr.outvars)

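# `UndefinedPrimal` marks an input to `backward_pass` whose primal value is
# unknown (i.e. a value being transposed with respect to); transpose rules
# receive it in place of the actual value and may only use its aval.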
class UndefinedPrimal:
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)

def is_undefined_primal(x):
  return type(x) is UndefinedPrimal

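# Registered as a pytree node with no leaves, so tree-flattening passes an
# UndefinedPrimal through in the treedef rather than as a traced value
# (relied on when transposing call primitives below).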
register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))

def get_primitive_transpose(p):
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Reverse-mode differentiation rule for '{}' not implemented".format(p)
    ) from err

class JVPTrace(Trace):
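  """Forward-mode trace: each JVPTracer carries a (primal, tangent) pair,
  and primitive applications are interpreted with the rules registered in
  `primitive_jvps`."""
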
  def pure(self, val):
    return JVPTracer(self, val, zero)

  def lift(self, val):
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    try:
      jvp = primitive_jvps[primitive]
    except KeyError as err:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive)) from err
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    primals = [t.primal for t in tracers]
    tangents = [t.tangent for t in tracers]
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                    len(primals), in_tree_def)
    name = params.get('name', f.__name__)
    params = dict(params, name=wrap_name(name, 'jvp'))
    result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out = primals + tangents
    del primals, tangents
    master = self.master
    def todo(x):
      n = len(x) // 2
      primals, tangents = x[:n], x[n:]
      trace = JVPTrace(master, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    return out, todo

  def join(self, xt, yt):
    xz, yz = xt is zero, yt is zero
    if xz == yz:
      return xt, yt
    elif yz and not xz:
      return xt, zeros_like_jaxval(xt)
    elif xz and not yz:
      return zeros_like_jaxval(yt), yt
    else:
      raise TypeError((xt, yt))

class JVPTracer(Tracer):
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    if self.tangent is zero:
      return core.full_lower(self.primal)
    else:
      return self

def _primal_tangent_shapes_match(primal, tangent):
  if tangent is not zero:
    primal_aval = raise_to_shaped(get_aval(primal))
    tangent_aval = raise_to_shaped(get_aval(tangent))
    assert primal_aval == tangent_aval


# -------------------- Primitives --------------------

primitive_jvps: Dict[core.Primitive, Callable] = {}

primitive_transposes: Dict[core.Primitive, Callable] = {}

def deflinear(primitive, transpose_rule):
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)

def linear_jvp(primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  if all(tangent is zero for tangent in tangents):
    return val_out, zero
  else:
    tangents = map(instantiate_zeros, primals, tangents)
    return val_out, primitive.bind(*tangents, **params)

def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  return zero if cotangent is zero else transpose_rule(cotangent, **kwargs)

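# For a linear primitive the JVP is the primitive itself applied to the
# tangents, so only a transpose rule needs supplying. This module does so
# below, e.g. `deflinear(add_jaxvals_p, lambda t: (t, t))`: addition fans
# the incoming cotangent out to both addends.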
def deflinear2(primitive, transpose_rule):
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)

def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  return zero if cotangent is zero else transpose_rule(cotangent, *args,
                                                       **kwargs)

def defjvp(primitive, *jvprules):
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)


def standard_jvp(jvprules, primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  tangents_out = [rule(t, *primals, **params) for rule, t
                  in zip(jvprules, tangents)
                  if rule is not None and t is not zero]
  return val_out, functools.reduce(add_tangents, tangents_out, zero)

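# Each rule in `defjvp` maps one argument's tangent `g` (plus all primals)
# to that argument's contribution to the output tangent; contributions are
# summed. A sketch of a registration, with names like `sin_p`, `mul`, and
# `cos` assumed from jax.lax rather than this module:
#   defjvp(sin_p, lambda g, x: mul(g, cos(x)))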
def defjvp2(primitive, *jvprules):
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)

def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  tangents_out = (rule(t, val_out, *primals, **params) for rule, t
                  in zip(jvprules, tangents)
                  if rule is not None and t is not zero)
  return val_out, functools.reduce(add_tangents, tangents_out, zero)

def add_tangents(x, y):
  if x is zero:
    return y
  elif y is zero:
    return x
  else:
    return add_jaxvals(x, y)

def defvjp_all(prim, custom_vjp):
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    ts = map(instantiate_zeros, xs, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    primals, tangents = split_list(primals_and_tangents,
                                   [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
    jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
                                      instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose

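# `custom_vjp` must map primals to (primal_out, vjp_function), where the
# vjp function maps output cotangents to a flat list of input cotangents.
# A sketch, with `square_p` a hypothetical unary primitive:
#   defvjp_all(square_p, lambda x: (x ** 2, lambda ct: [2. * x * ct]))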
def defvjp(prim, *vjps):
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
                         for x, vjp in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)

def defvjp2(prim, *vjps):
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
                         for x, vjp in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)

def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  assert isinstance(prim, Primitive)
  lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
  rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, bcast(g, x), **kwargs)
  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)

def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  assert is_undefined_primal(x) ^ is_undefined_primal(y)
  if is_undefined_primal(x):
    out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
    return out, None
  else:
    out = zero if cotangent is zero else rhs_rule(cotangent, x, **kwargs)
    return None, out

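# A bilinear primitive is linear in each argument separately, so its JVP is
# the primitive itself applied once per tangent, and its transpose needs a
# rule per side. A sketch, with `dot_p` and the rules assumed hypothetical:
#   defbilinear(dot_p, dot_transpose_lhs, dot_transpose_rhs)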
def defjvp_zero(primitive):
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(zero_jvp, primitive)

def zero_jvp(primitive, primals, tangents, **params):
  return primitive.bind(*primals, **params), zero

deflinear(zeros_like_p, lambda t: [zero])
deflinear(core.identity_p, lambda t: (t,))
deflinear(add_jaxvals_p, lambda t: (t, t))

def instantiate_zeros(example, tangent):
  if tangent is zero:
    return zeros_like_jaxval(example)
  else:
    return tangent

def instantiate_zeros_aval(aval, tangent):
  if tangent is zero:
    return zeros_like_aval(aval)
  else:
    return tangent

@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  new_primals = primals_and_tangents[:num_primals]
  new_tangents = primals_and_tangents[num_primals:]
  new_tangents = tree_unflatten(in_tree_def, new_tangents)
  primal_out, tangent_out = yield (new_primals, new_tangents), {}
  out_flat, tree_def = tree_flatten((primal_out, tangent_out))
  yield out_flat, tree_def


def call_transpose(primitive, params, call_jaxpr, args, ct):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

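# Transposing a call rebinds the same call primitive with its body replaced
# by `backward_pass` over the body jaxpr, so the transpose of a traced
# subcomputation is itself a traced subcomputation.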
def map_transpose(primitive, params, call_jaxpr, args, ct):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(fun, *all_args, **params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  mapped_invars = params['mapped_invars']  # True for each mapped invar
  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(mapped_invars) == len(arg_cts)
  arg_cts = (arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
             for arg_ct, arg_mapped in zip(arg_cts, mapped_invars))

  return arg_cts

def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate),
                                        nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  pvals = [pe.PartialVal((aval, core.unit)) for aval in avals_in]
  jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals,
                                                         instantiate=True)
  avals_out, _ = unzip2(pvals_out)
  jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
  return jaxpr_out, out_nonzeros()

@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  nonzero_tangents = iter(primals_and_nztangents[num_primals:])
  tangents = [next(nonzero_tangents) if nz else zero for nz in nonzeros]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [t is not zero for t in tangents_out]
  nonzero_tangents_out = [t for t in tangents_out if t is not zero]
  yield list(primals_out) + nonzero_tangents_out, out_nonzeros

def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
  new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
  new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                         new_invars, new_outvars, jaxpr.jaxpr.eqns)
  new_in_avals = _perm(primals_in, tangents_in, jaxpr.in_avals)
  new_out_avals = _perm(primals_out, tangents_out, jaxpr.out_avals)
  new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, jaxpr.literals, new_in_avals,
                                    new_out_avals)
  return new_typed_jaxpr

def _perm(primal_counts, tangent_counts, lst):
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)

def _interleave(xs, ys):
  assert len(xs) == len(ys)
  return [e for pair in zip(xs, ys) for l in pair for e in l]
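
# For example, with primal_counts=[2, 1] and tangent_counts=[1, 1], `_perm`
# regroups [p0, p1, p2, t0, t1] into [p0, p1, t0, p2, t1] (a worked sketch,
# not a call site from this module).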