mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
introduce linear_call
for custom transposition.
This adds a primitive with a corresponding traceable function in `custom_derivatives` that takes a callee and its transpose, both functions. When the primitive is encountered during transposition, the given transpose function is invoked instead of transpose-transforming the callee. The invocation of the custom transposition is itself done via a `linear_call`, with the original callee set as the transpose. This maintains, in particular, that transposing twice is an identity.
This commit is contained in:
parent
a0c5a80971
commit
912cc87a3d
@ -73,7 +73,7 @@ from .interpreters import masking
|
||||
from .interpreters import invertible_ad as iad
|
||||
from .interpreters.invertible_ad import custom_ivjp
|
||||
from .custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
custom_vjp)
|
||||
custom_vjp, linear_call)
|
||||
from .config import flags, config, bool_env
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
@ -21,8 +21,9 @@ from typing import Callable, Sequence, Tuple, Any
|
||||
from . import core
|
||||
from . import dtypes
|
||||
from . import linear_util as lu
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
|
||||
treedef_is_leaf, register_pytree_node_class)
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class)
|
||||
from ._src.util import cache, safe_zip, safe_map, split_list
|
||||
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
|
||||
from .core import raise_to_shaped
|
||||
@ -55,6 +56,9 @@ def _initial_style_jaxpr(fun, in_avals):
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
return jaxpr, consts
|
||||
|
||||
def _close_jaxpr(jaxpr):
  """Wrap ``jaxpr`` in a ``core.ClosedJaxpr`` that carries no constants.

  The jaxpr is first passed through ``pe.convert_constvars_jaxpr`` and the
  result is paired with an empty tuple of constants.
  """
  converted = pe.convert_constvars_jaxpr(jaxpr)
  return core.ClosedJaxpr(converted, ())
|
||||
|
||||
def _initial_style_staging() -> bool:
  """Return the ``initial_style`` flag of the thread-local trace state."""
  trace_state = core.thread_local_state.trace_state
  return trace_state.initial_style
|
||||
|
||||
@ -951,3 +955,166 @@ def partition_list(choice, lst):
|
||||
|
||||
def abstractify(x):
  """Return the abstract value of ``x`` after ``core.raise_to_shaped``."""
  aval = core.get_aval(x)
  return core.raise_to_shaped(aval)
|
||||
|
||||
|
||||
### Custom transposition
|
||||
|
||||
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type, ``a`` is the linear input type and ``b`` is the
  linear output type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example, if::

    def f(r, x):
      return x / r

    def t(r, t):
      return t / r

    def div_add(denom, x):
      return x + linear_call(f, t, denom, x)

    def transpose(f, x_example):
      def transposed(y):
        x, = jax.linear_transpose(f, x_example)(y)
        return x
      return transposed

  Then:

  >>> div_add(3., 9.)
  12.0
  >>> transpose(partial(div_add, 3.), 1.)(18.)  # custom
  24.0
  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  24.0

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example, if::

    def custom_id(x):
      def f(_, x): return x
      def t(_, t): return 7.
      return linear_call(f, t, (), x)

  Then:

  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``).  It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  Raises:
    TypeError: if the pytree structure of the output of ``fun_transpose``
      differs from the pytree structure of ``linear_args``.
  """
  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  f_in_tree = treedef_tuple((res_tree, lin_tree))
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

  # Build the avals as concrete lists: ``res_avals`` is unpacked twice below
  # (once per traced function), so it must not be a one-shot iterator.
  res_avals = [abstractify(x) for x in operands_res]
  lin_avals = [abstractify(x) for x in operands_lin]
  f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
  f_jaxpr = _close_jaxpr(f_jaxpr)
  out_avals = [core.raise_to_shaped(a) for a in f_jaxpr.out_avals]

  # The transpose takes the same residuals plus inputs shaped like the
  # outputs of ``fun``; ``out_tree()`` is available because ``f`` has been
  # traced above.
  t_in_tree = treedef_tuple((res_tree, out_tree()))
  t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

  t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
  t_jaxpr = _close_jaxpr(t_jaxpr)

  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  # Operand layout: callee consts, transpose consts, residuals, linear args.
  out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
                           callee=f_jaxpr,
                           transpose=t_jaxpr,
                           num_callee_consts=len(f_consts),
                           num_transpose_consts=len(t_consts),
                           num_res=len(operands_res))

  return tree_unflatten(out_tree(), out)
|
||||
|
||||
def _linear_call_impl(*args, callee, transpose, num_callee_consts,
                      num_transpose_consts, num_res):
  """Evaluate a ``linear_call`` by running the callee jaxpr.

  ``args`` is split, in order, into callee constants, transpose constants,
  residual operands and linear operands. The transpose jaxpr and its
  constants take no part in evaluation.
  """
  del transpose  # only consulted by the transpose rule
  split_sizes = [num_callee_consts, num_transpose_consts, num_res]
  callee_consts, _, res, lin = split_list(args, split_sizes)
  return core.eval_jaxpr(callee.jaxpr, (), *callee_consts, *res, *lin)
|
||||
|
||||
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts,
                                num_transpose_consts, num_res):
  """Transpose a ``linear_call`` by binding another one with roles swapped.

  The new ``linear_call`` uses ``transpose`` as its callee and ``callee`` as
  its transpose, and receives the residual operands together with the
  incoming cotangents ``cts``.

  Returns a list holding ``None`` for every constant and residual operand,
  followed by the outputs of the swapped call.
  """
  arg_sizes = [num_callee_consts, num_transpose_consts, num_res]
  f_consts, t_consts, operands_res, operands_lin = split_list(args, arg_sizes)
  # The trailing input avals of ``transpose`` describe the cotangents.
  _, _, cts_avals = split_list(
      transpose.in_avals, [num_transpose_consts, num_res])

  # Linear operands must all be undefined primals; residuals must not be.
  assert all(ad.is_undefined_primal(x) for x in operands_lin)
  assert not any(ad.is_undefined_primal(x) for x in operands_res)

  # Replace symbolic zero cotangents with zeros built from the matching aval.
  dense_cts = []
  for ct, aval in zip(cts, cts_avals):
    if type(ct) is Zero:
      ct = zeros_like_aval(aval)
    dense_cts.append(ct)

  cts_out = linear_call_p.bind(
      *t_consts, *f_consts, *operands_res, *dense_cts,
      callee=transpose,
      transpose=callee,
      num_callee_consts=len(t_consts),
      num_transpose_consts=len(f_consts),
      num_res=len(operands_res))

  num_nonlinear = num_callee_consts + num_transpose_consts + num_res
  return [None] * num_nonlinear + cts_out
|
||||
|
||||
def _linear_call_abstract_eval(*args, **kwargs):
  """Return the callee's output avals, each passed to ``core.raise_to_shaped``.

  Positional operands are ignored; only the ``callee`` parameter is read.
  """
  callee = kwargs['callee']
  return map(core.raise_to_shaped, callee.out_avals)
|
||||
|
||||
# The ``linear_call`` primitive; it is marked as returning multiple results.
linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
# Concrete and abstract evaluation are both driven by the ``callee`` jaxpr.
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
# Transposition binds a new ``linear_call`` with callee and transpose swapped.
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
|
||||
|
@ -4672,6 +4672,118 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(g_x, 17. * jnp.ones(2), check_dtypes=False)
|
||||
|
||||
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
  """Tests for ``api.linear_call`` and its custom transposition."""

  def transpose(self, f, x_example):
    """Return a function that applies the linear transpose of ``f``.

    The transpose is taken via ``api.linear_transpose(f, x_example)`` and
    its single-element result is unpacked.
    """
    def transposed(y):
      x, = api.linear_transpose(f, x_example)(y)
      return x
    return transposed

  def test_linear_call(self):
    """Forward value and transpose agree with a plain reference function."""
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / r
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, y)
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_incorrect_transpose(self):
    """A deliberately wrong transpose is used as given, not derived."""
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / (2. * r)  # nb: not the true transpose
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, 2. * y)  # nb: double the reference divisor
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_transpose_transpose_transpose(self):
    """Transposing twice recovers ``f``; three times matches one transpose."""
    def fn(r, x): return x / r
    def tp(r, t): return t / (2. * r)  # nb: untrue transpose
    def f_(x, y):
      return x + api.linear_call(fn, tp, y, x)

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    f = lambda x: f_(x, y)
    ft   = self.transpose(f,   x)
    ftt  = self.transpose(ft,  x)
    fttt = self.transpose(ftt, x)
    self.assertAllClose(ft(x), x + tp(y, x))
    self.assertAllClose(f(x),  ftt(x))
    self.assertAllClose(ft(x), fttt(x))

  def test_linear_call_scalar_to_vector(self):
    """The callee returns a two-element list; the transpose sums a pair."""
    def f(c, x):
      def fn(_, x):
        return [x, x]

      def tp(_, t):
        t1, t2 = t
        return t1 + t2

      return api.linear_call(fn, tp, (), c * x)

    def f_ref(c, x):
      return [c * x, c * x]

    c, x = 2., 3.
    t = [4., 5.]
    self.assertAllClose(f(c, x), f_ref(c, x))
    self.assertAllClose(self.transpose(partial(f, c), x)(t),
                        self.transpose(partial(f_ref, c), x)(t))

  def test_linear_call_nested(self):
    """A ``linear_call`` whose callee and transpose each use ``linear_call``."""
    # identity function with an untrue transpose of 0
    def id_(x):
      def f(_, x): return x
      def t(_, t): return 0.
      return api.linear_call(f, t, (), x)

    # identity function with an untrue transpose of 7, and where both
    # forward and transpose have custom transpositions that should
    # never end up invoked.
    def f(x):
      def f_(_, x): return id_(x)
      def t_(_, t): return id_(7.)
      return api.linear_call(f_, t_, (), x)

    x = 5.
    id_t  = self.transpose(id_,  x)
    id_tt = self.transpose(id_t, x)
    ft   = self.transpose(f,    x)
    ftt  = self.transpose(ft,   x)
    fttt = self.transpose(ftt,  x)

    self.assertAllClose(id_(x),   x)
    self.assertAllClose(id_t(x),  0.)
    self.assertAllClose(id_tt(x), x)

    self.assertAllClose(f(x),    x)
    self.assertAllClose(ft(x),   7.)
    self.assertAllClose(ftt(x),  x)
    self.assertAllClose(fttt(x), 7.)
|
||||
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
def test_invertible_basic(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user