Introduce `linear_call` for custom transposition.

This adds a primitive with a corresponding traceable function in
`custom_derivatives` that takes a callee and its transpose, both
functions. When the primitive is encountered during transposition, the
given transpose function is invoked instead of transpose-transforming
the callee. The invocation of the custom transposition is itself done
via a `linear_call`, with the original callee set as the transpose.
This maintains, in particular, that transposing twice is an identity.
This commit is contained in:
Roy Frostig 2021-02-12 12:56:15 -08:00
parent a0c5a80971
commit 912cc87a3d
3 changed files with 282 additions and 3 deletions

View File

@ -73,7 +73,7 @@ from .interpreters import masking
from .interpreters import invertible_ad as iad
from .interpreters.invertible_ad import custom_ivjp
from .custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp)
custom_vjp, linear_call)
from .config import flags, config, bool_env
traceback_util.register_exclusion(__file__)

View File

@ -21,8 +21,9 @@ from typing import Callable, Sequence, Tuple, Any
from . import core
from . import dtypes
from . import linear_util as lu
from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
treedef_is_leaf, register_pytree_node_class)
from .tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
register_pytree_node_class)
from ._src.util import cache, safe_zip, safe_map, split_list
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from .core import raise_to_shaped
@ -55,6 +56,9 @@ def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr, consts
def _close_jaxpr(jaxpr):
  """Close an open jaxpr, lifting its constvars into regular inputs."""
  lifted = pe.convert_constvars_jaxpr(jaxpr)
  return core.ClosedJaxpr(lifted, ())
def _initial_style_staging() -> bool:
  """Report whether tracing is currently in initial style."""
  trace_state = core.thread_local_state.trace_state
  return trace_state.initial_style
@ -951,3 +955,166 @@ def partition_list(choice, lst):
def abstractify(x):
  """Return the shaped abstract value (aval) corresponding to ``x``."""
  aval = core.get_aval(x)
  return core.raise_to_shaped(aval)
### Custom transposition
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example, if::

    def f(r, x):
      return x / r

    def t(r, t):
      return t / r

    def div_add(denom, x):
      return x + linear_call(f, t, denom, x)

    def transpose(f, x_example):
      def transposed(y):
        x, = jax.linear_transpose(f, x_example)(y)
        return x
      return transposed

  Then:

  >>> div_add(9., 3.)
  12.0

  >>> transpose(partial(div_add, 3.), 1.)(18.)  # custom
  24.0

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  24.0

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example, if::

    def custom_id(x):
      def f(_, x): return x
      def t(_, t): return 7.
      return linear_call(f, t, (), x)

  Then:

  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``). It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.
  """
  # Flatten both argument pytrees so the primitive sees flat operand lists.
  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  # Stage the callee to a jaxpr over (residuals, linears).
  f_in_tree = treedef_tuple((res_tree, lin_tree))
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

  res_avals = map(abstractify, operands_res)
  lin_avals = map(abstractify, operands_lin)
  f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
  f_jaxpr = _close_jaxpr(f_jaxpr)
  out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

  # Stage the transpose to a jaxpr over (residuals, callee outputs). Its
  # second argument type (``b``) is the callee's output type by construction.
  t_in_tree = treedef_tuple((res_tree, out_tree()))
  t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

  t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
  t_jaxpr = _close_jaxpr(t_jaxpr)

  # The transpose must produce values of the linear input type ``a``.
  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  # Bind both jaxprs' consts as leading operands; the num_* parameters let
  # the impl/transpose rules split the flat operand list back apart.
  out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
                           callee=f_jaxpr,
                           transpose=t_jaxpr,
                           num_callee_consts=len(f_consts),
                           num_transpose_consts=len(t_consts),
                           num_res=len(operands_res))

  return tree_unflatten(out_tree(), out)
def _linear_call_impl(*args, callee, transpose, num_callee_consts,
                      num_transpose_consts, num_res):
  """Evaluate a linear_call by running the callee jaxpr on its operands."""
  del transpose  # only relevant under transposition, not evaluation
  split_sizes = [num_callee_consts, num_transpose_consts, num_res]
  callee_consts, _, res_operands, lin_operands = split_list(args, split_sizes)
  return core.eval_jaxpr(callee.jaxpr, (),
                         *callee_consts, *res_operands, *lin_operands)
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts,
                                num_transpose_consts, num_res):
  """Transpose a linear_call by emitting another linear_call to the
  user-supplied transpose, with callee and transpose swapped. Swapping the
  two is what makes transposing twice an identity.
  """
  f_consts, t_consts, operands_res, operands_lin = split_list(
      args, [num_callee_consts, num_transpose_consts, num_res])
  # The transpose jaxpr's inputs are (t_consts, residuals, cotangents); the
  # trailing segment gives the avals at which to instantiate zero cotangents.
  _, _, cts_avals = split_list(
      transpose.in_avals, [num_transpose_consts, num_res])

  # Only the linear operands are being transposed; residuals are known.
  assert all(ad.is_undefined_primal(x) for x in operands_lin)
  assert all(not ad.is_undefined_primal(x) for x in operands_res)

  # Materialize symbolic zeros so the bound jaxpr receives concrete values.
  cts = [zeros_like_aval(a) if type(ct) is Zero else ct
         for ct, a in zip(cts, cts_avals)]

  # Note the role swap: transpose becomes the callee, callee the transpose.
  cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts,
                               callee=transpose,
                               transpose=callee,
                               num_callee_consts=len(t_consts),
                               num_transpose_consts=len(f_consts),
                               num_res=len(operands_res))

  # No cotangents flow to the consts or residual operands.
  return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out
def _linear_call_abstract_eval(*args, **kwargs):
  """Abstract evaluation: outputs have the callee jaxpr's output avals."""
  del args
  callee = kwargs['callee']
  return map(core.raise_to_shaped, callee.out_avals)
# Primitive carrying both the callee and its custom transpose as jaxprs.
linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True  # callee outputs are a flat list
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
# Transposition invokes the user-supplied transpose instead of
# transpose-transforming the callee.
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule

View File

@ -4672,6 +4672,118 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(g_x, 17. * jnp.ones(2), check_dtypes=False)
class CustomTransposeTest(jtu.JaxTestCase):
  """Tests for custom transposition via ``linear_call``."""

  def transpose(self, f, x_example):
    # Build the transpose of ``f`` with jax.linear_transpose; ``x_example``
    # only communicates the input shape/dtype.
    def transposed(y):
      x, = api.linear_transpose(f, x_example)(y)
      return x
    return transposed

  def test_linear_call(self):
    # With the true transpose supplied, both the primal value and the
    # transpose agree with the reference implementation.
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / r
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, y)
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_incorrect_transpose(self):
    # The custom transpose is taken at face value even when it is not the
    # mathematical transpose of the callee.
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / (2. * r)  # nb: not the true transpose
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, 2. * y)  # nb: double the reference divisor
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_transpose_transpose_transpose(self):
    # Transposing twice is an identity, so odd numbers of transposes give
    # the (untrue) custom transpose and even numbers give the original.
    def fn(r, x): return x / r
    def tp(r, t): return t / (2. * r)  # nb: untrue transpose
    def f_(x, y):
      return x + api.linear_call(fn, tp, y, x)

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    f = lambda x: f_(x, y)
    ft = self.transpose(f, x)
    ftt = self.transpose(ft, x)
    fttt = self.transpose(ftt, x)
    self.assertAllClose(ft(x), x + tp(y, x))
    self.assertAllClose(f(x), ftt(x))
    self.assertAllClose(ft(x), fttt(x))

  def test_linear_call_scalar_to_vector(self):
    # Callee maps a scalar to a two-element list; the transpose maps the
    # list back to a scalar, exercising mismatched in/out pytrees.
    def f(c, x):
      def fn(_, x):
        return [x, x]

      def tp(_, t):
        t1, t2 = t
        return t1 + t2

      return api.linear_call(fn, tp, (), c * x)

    def f_ref(c, x):
      return [c * x, c * x]

    c, x = 2., 3.
    t = [4., 5.]
    self.assertAllClose(f(c, x), f_ref(c, x))
    self.assertAllClose(self.transpose(partial(f, c), x)(t),
                        self.transpose(partial(f_ref, c), x)(t))

  def test_linear_call_nested(self):
    # identity function with an untrue transpose of 0
    def id_(x):
      def f(_, x): return x
      def t(_, t): return 0.
      return api.linear_call(f, t, (), x)

    # identity function with an untrue transpose of 7, and where both
    # forward and transpose have custom transpositions that should
    # never end up invoked.
    def f(x):
      def f_(_, x): return id_(x)
      def t_(_, t): return id_(7.)
      return api.linear_call(f_, t_, (), x)

    x = 5.
    id_t = self.transpose(id_, x)
    id_tt = self.transpose(id_t, x)
    ft = self.transpose(f, x)
    ftt = self.transpose(ft, x)
    fttt = self.transpose(ftt, x)

    self.assertAllClose(id_(x), x)
    self.assertAllClose(id_t(x), 0.)
    self.assertAllClose(id_tt(x), x)

    self.assertAllClose(f(x), x)
    self.assertAllClose(ft(x), 7.)
    self.assertAllClose(ftt(x), x)
    self.assertAllClose(fttt(x), 7.)
class InvertibleADTest(jtu.JaxTestCase):
def test_invertible_basic(self):