allow for recursive uses of custom_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Roy Frostig 2022-03-24 17:27:06 -07:00
parent f7df3ee9c4
commit a6a43e2715
5 changed files with 100 additions and 15 deletions

View File

@ -19,6 +19,7 @@ from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
@ -133,6 +134,16 @@ def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')
def make_transpose_from_thunk(thunk, lin_tree):
  """Build a transpose callable from a thunk producing a jaxpr and consts.

  Args:
    thunk: zero-argument callable; its result is unpacked as the pair
      ``(transpose_jaxpr, transpose_consts)``. It is invoked once, when
      this function is called.
    lin_tree: tree structure passed to ``tree_unflatten`` to rebuild the
      result of the transpose from its flat outputs.

  Returns:
    A function ``transpose(res_arg, ct_out)`` that flattens both arguments,
    evaluates the transpose jaxpr on the consts followed by those leaves,
    and unflattens the outputs according to ``lin_tree``.
  """
  raw_jaxpr, consts = thunk()
  # Turn the constvars into ordinary inputs so that the consts can be
  # passed positionally, ahead of the flattened arguments, at call time.
  converted = pe.convert_constvars_jaxpr(raw_jaxpr)
  closed_jaxpr = core.ClosedJaxpr(converted, ())

  def transpose(res_arg, ct_out):
    flat_inputs = tree_leaves((res_arg, ct_out))
    evaluate = core.jaxpr_as_fun(closed_jaxpr)
    flat_cts = evaluate(*consts, *flat_inputs)
    return tree_unflatten(lin_tree, flat_cts)

  return transpose
### custom_transpose primitive and rules
@ -157,8 +168,14 @@ class CustomTransposePrimitive(core.Primitive):
# TODO(frostig,mattjj): consider keeping `call` as a named parameter
# instead of following this "call primitive" convention.
def get_bind_params(self, params):
  """Rebuild the bind-time function and params from jaxpr-form params.

  Args:
    params: mapping that must contain ``'call_jaxpr'`` and
      ``'transpose_jaxpr_thunk'`` (both asserted) as well as
      ``'lin_tree'`` (read when building the transpose).

  Returns:
    A pair ``([call], new_params)``. ``call`` is ``call_jaxpr`` turned into
    a function with ``core.jaxpr_as_fun`` and wrapped by ``lu.wrap_init``.
    ``new_params`` is a copy of ``params`` with ``'call_jaxpr'`` and
    ``'transpose_jaxpr_thunk'`` removed and a ``'transpose'`` callable,
    built by ``make_transpose_from_thunk``, added.
  """
  assert 'call_jaxpr' in params
  assert 'transpose_jaxpr_thunk' in params
  # Work on a copy so the caller's params mapping is left untouched.
  new_params = dict(params)
  # The params asserted above carry 'call_jaxpr', not 'call', so there is
  # no early return on a 'call' entry: doing so would skip the conversion
  # below and pop a key that is not guaranteed to be present.
  new_params['transpose'] = make_transpose_from_thunk(
      new_params.pop('transpose_jaxpr_thunk'),
      new_params['lin_tree'])
  call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
  return [call], new_params
# TODO(frostig,mattjj): reinstate checks
@ -167,7 +184,16 @@ def custom_transpose_typecheck(*avals, **params):
def custom_transpose_transpose_rule(
cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
cts, *args, out_types, res_tree, lin_tree, out_tree, **params):
if 'transpose_jaxpr_thunk' in params:
assert 'call_jaxpr' in params
transpose = make_transpose_from_thunk(
params['transpose_jaxpr_thunk'], lin_tree)
else:
assert 'call' in params
transpose = params['transpose']
call_in_tree = treedef_tuple((res_tree, lin_tree))
# TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect

View File

@ -805,7 +805,7 @@ def switch(index, branches: Sequence[Callable], *operands,
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel):
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
``cond()`` has equivalent semantics to this Python implementation::
@ -865,6 +865,12 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
return false_fun(*operands)
ops, ops_tree = tree_flatten(operands)
if linear is None:
linear_ops = [False] * len(ops)
else:
linear_ops, ops_tree2 = tree_flatten(linear)
if ops_tree != ops_tree2:
raise TypeError('linear tree and operand tree mismatch')
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@ -878,10 +884,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
index = lax.convert_element_type(pred, np.int32)
linear = (False,) * (len(consts) + len(ops))
linear = [False] * len(consts) + linear_ops
out = cond_p.bind(
index, *consts, *ops,
branches=(false_jaxpr, true_jaxpr), linear=linear)
branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
return tree_unflatten(out_tree, out)
@api_boundary

View File

@ -164,8 +164,10 @@ def recast_to_float0(primal, tangent):
else:
return tangent
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
# errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
consts, primals_in, cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in):
return map(lambda v: Zero(v.aval), jaxpr.invars)

View File

@ -1626,10 +1626,12 @@ class DynamicJaxprTrace(core.Trace):
transpose_flat, in_tree2 = flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
transpose_flat, self.main, in_avals_t)
closed_transpose_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(transpose_jaxpr), ())
main_ = ref(self.main)
# the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
transpose_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(
transpose_flat, main_(), in_avals_t)[::2])
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
@ -1637,9 +1639,9 @@ class DynamicJaxprTrace(core.Trace):
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_call_jaxpr,
transpose_jaxpr=(closed_transpose_jaxpr,
transpose_consts),
num_consts=len(call_consts)),
transpose_jaxpr_thunk=transpose_jaxpr_thunk,
out_types=out_types, res_tree=res_tree,
lin_tree=lin_tree, out_tree=out_tree),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers

View File

@ -6891,7 +6891,6 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f_t(x), g_t(x))
def test_jit_recursive(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@ -6913,6 +6912,56 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))
def test_cond(self):
  """A custom_transpose function gives the same results inside `lax.cond`.

  Compares a function using `custom_transpose` against the same function
  routed through the true branch of `lax.cond` with its operand marked
  linear, both for the forward values and for the `transpose_unary`
  results.
  """
  def f(x, y):
    @custom_transpose(jnp.ones(2))
    def fn(r, x):
      return x / r

    @fn.def_transpose
    def tp(r, t):
      return 2 * t / r

    return x + fn(y, x)

  def cond_wrap(f):
    def through_cond(i, x):
      return lax.cond(i > 0, f, lambda x: x, x, linear=(True,))
    return through_cond

  i = 7.
  x = 6. * jnp.ones(2)
  y = 3. * jnp.ones(2)

  def f_(x):
    return f(x, y)

  f_t = transpose_unary(f_, x)
  g_ = partial(cond_wrap(f_), i)
  g_t = transpose_unary(g_, x)

  # Forward values first, then the transposed functions.
  for direct, conditional in ((f_, g_), (f_t, g_t)):
    self.assertAllClose(direct(x), conditional(x))
def test_cond_recursive(self):
  """Like `test_cond`, but the transpose rule calls the function itself.

  The transpose rule registered for `fn` is written in terms of `fn`, so
  this exercises a recursive use of `custom_transpose` under `lax.cond`.
  Forward values and `transpose_unary` results must match the version
  without `lax.cond`.
  """
  def f(x, y):
    @custom_transpose(jnp.ones(2))
    def fn(r, x):
      return x / r

    @fn.def_transpose
    def tp(r, t):
      return 2 * fn(r, t)

    return x + fn(y, x)

  def cond_wrap(f):
    def through_cond(i, x):
      return lax.cond(i > 0, f, lambda x: x, x, linear=(True,))
    return through_cond

  i = 7.
  x = 6. * jnp.ones(2)
  y = 3. * jnp.ones(2)

  def f_(x):
    return f(x, y)

  f_t = transpose_unary(f_, x)
  g_ = partial(cond_wrap(f_), i)
  g_t = transpose_unary(g_, x)

  # Forward values first, then the transposed functions.
  for direct, conditional in ((f_, g_), (f_t, g_t)):
    self.assertAllClose(direct(x), conditional(x))
class CustomVmapTest(jtu.JaxTestCase):