allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
parent f7df3ee9c4
commit a6a43e2715
@@ -19,6 +19,7 @@ from jax import core
 from jax import linear_util as lu
 from jax.interpreters import ad
 from jax.interpreters import mlir
+from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
                            tree_structure, treedef_tuple, tree_unflatten)
@@ -133,6 +134,16 @@ def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
        f'Transpose rule output: {rule_out_tree}\n'
        f'Linear primal inputs: {lin_tree}')
 
+def make_transpose_from_thunk(thunk, lin_tree):
+  transpose_jaxpr, transpose_consts = thunk()
+  transpose_jaxpr = core.ClosedJaxpr(
+      pe.convert_constvars_jaxpr(transpose_jaxpr), ())
+  def transpose(res_arg, ct_out):
+    args_flat = tree_leaves((res_arg, ct_out))
+    ct_ins = core.jaxpr_as_fun(transpose_jaxpr)(*transpose_consts, *args_flat)
+    return tree_unflatten(lin_tree, ct_ins)
+  return transpose
+
 
 ### custom_transpose primitive and rules
 
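For orientation, a minimal usage sketch of the pattern this helper supports, mirroring the tests added at the end of this diff: a custom_transpose function whose transpose rule calls the function itself, the recursive case this change enables when the function is staged out (for example under jit or cond). The import path shown is an assumption; the tests only use custom_transpose already in scope.

import jax.numpy as jnp
# Assumed module path for illustration; the tests below use `custom_transpose`
# without showing its import.
from jax._src.custom_transpose import custom_transpose

@custom_transpose(jnp.ones(2))   # example output, as in the tests below
def fn(r, x):
  return x / r

@fn.def_transpose
def tp(r, t):
  # The transpose rule refers back to `fn` itself: the recursive use.
  return 2 * fn(r, t)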
@@ -157,8 +168,14 @@ class CustomTransposePrimitive(core.Primitive):
   # TODO(frostig,mattjj): consider keeping `call` as a named parameter
   # instead of following this "call primitive" convention.
   def get_bind_params(self, params):
+    assert 'call_jaxpr' in params
+    assert 'transpose_jaxpr_thunk' in params
     new_params = dict(params)
-    return [new_params.pop('call')], new_params
+    new_params['transpose'] = make_transpose_from_thunk(
+        new_params.pop('transpose_jaxpr_thunk'),
+        new_params['lin_tree'])
+    call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
+    return [call], new_params
 
 
 # TODO(frostig,mattjj): reinstate checks
@@ -167,7 +184,16 @@ def custom_transpose_typecheck(*avals, **params):
 
 
 def custom_transpose_transpose_rule(
-    cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
+    cts, *args, out_types, res_tree, lin_tree, out_tree, **params):
+
+  if 'transpose_jaxpr_thunk' in params:
+    assert 'call_jaxpr' in params
+    transpose = make_transpose_from_thunk(
+        params['transpose_jaxpr_thunk'], lin_tree)
+  else:
+    assert 'call' in params
+    transpose = params['transpose']
+
   call_in_tree = treedef_tuple((res_tree, lin_tree))
 
   # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
@@ -805,7 +805,7 @@ def switch(index, branches: Sequence[Callable], *operands,
 
 
 def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
-          operand=_no_operand_sentinel):
+          operand=_no_operand_sentinel, linear=None):
   """Conditionally apply ``true_fun`` or ``false_fun``.
 
   ``cond()`` has equivalent semantics to this Python implementation::
@@ -865,6 +865,12 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
       return false_fun(*operands)
 
   ops, ops_tree = tree_flatten(operands)
+  if linear is None:
+    linear_ops = [False] * len(ops)
+  else:
+    linear_ops, ops_tree2 = tree_flatten(linear)
+    if ops_tree != ops_tree2:
+      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(_map(_abstractify, ops))
 
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@@ -878,10 +884,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
 
   index = lax.convert_element_type(pred, np.int32)
 
-  linear = (False,) * (len(consts) + len(ops))
+  linear = [False] * len(consts) + linear_ops
   out = cond_p.bind(
       index, *consts, *ops,
-      branches=(false_jaxpr, true_jaxpr), linear=linear)
+      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
   return tree_unflatten(out_tree, out)
 
 @api_boundary
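The new linear argument lets callers mark operands as linear so the resulting cond can be transposed with respect to them; the flags are flattened against the operand pytree and forwarded to cond_p.bind. A minimal call sketch, following the tests added at the end of this diff:

import jax.numpy as jnp
from jax import lax

def scale(x):
  return 2. * x

x = jnp.ones(2)
# `linear=(True,)` mirrors the single operand `x`, marking it as linear so the
# conditional can later be transposed with respect to that operand.
y = lax.cond(True, scale, lambda x: x, x, linear=(True,))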
@@ -164,8 +164,10 @@ def recast_to_float0(primal, tangent):
   else:
     return tangent
 
-# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
-def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
+# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
+# errors if you will)
+def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
+                  consts, primals_in, cotangents_in):
   if all(type(ct) is Zero for ct in cotangents_in):
     return map(lambda v: Zero(v.aval), jaxpr.invars)
 
@@ -1626,10 +1626,12 @@ class DynamicJaxprTrace(core.Trace):
 
     transpose_flat, in_tree2 = flatten_fun_nokwargs(
         lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
-    transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
-        transpose_flat, self.main, in_avals_t)
-    closed_transpose_jaxpr = core.ClosedJaxpr(
-        convert_constvars_jaxpr(transpose_jaxpr), ())
+
+    main_ = ref(self.main)
+    # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
+    transpose_jaxpr_thunk = _memoize(
+        lambda: trace_to_subjaxpr_dynamic(
+            transpose_flat, main_(), in_avals_t)[::2])
 
     out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
     invars = map(self.getvar, tracers)
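The staged rule above no longer traces the transpose function eagerly. It stores a memoized thunk and only traces the transpose when it is first needed, which is what lets a transpose rule refer back to the same custom_transpose without recursing at trace time. A minimal sketch of the memoization pattern; the diff itself uses partial_eval's existing _memoize helper, so this standalone version is illustration only:

def memoize_thunk(thunk):
  # Call `thunk` at most once and cache the result; later calls return the
  # cached value. Deferring the call is what breaks trace-time recursion.
  cell = []
  def memoized():
    if not cell:
      cell.append(thunk())
    return cell[0]
  return memoized

# The (potentially recursive) tracing only happens when the thunk is forced.
trace_transpose = memoize_thunk(lambda: ("transpose_jaxpr", "transpose_consts"))
jaxpr, consts = trace_transpose()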
@@ -1637,9 +1639,9 @@ class DynamicJaxprTrace(core.Trace):
     outvars = map(self.makevar, out_tracers)
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                         dict(call_jaxpr=closed_call_jaxpr,
-                             transpose_jaxpr=(closed_transpose_jaxpr,
-                                              transpose_consts),
-                             num_consts=len(call_consts)),
+                             transpose_jaxpr_thunk=transpose_jaxpr_thunk,
+                             out_types=out_types, res_tree=res_tree,
+                             lin_tree=lin_tree, out_tree=out_tree),
                         source_info_util.current())
     self.frame.eqns.append(eqn)
     return out_tracers
@@ -6891,7 +6891,6 @@ class CustomTransposeTest(jtu.JaxTestCase):
     self.assertAllClose(f_t(x), g_t(x))
 
   def test_jit_recursive(self):
-    raise unittest.SkipTest('unimplemented')  # TODO(frostig,mattjj)
     def f(x, y):
       @custom_transpose(jnp.ones(2))
       def fn(r, x): return x / r
@@ -6913,6 +6912,56 @@ class CustomTransposeTest(jtu.JaxTestCase):
     self.assertAllClose(f_(x), g_(x))
     self.assertAllClose(f_t(x), g_t(x))
 
+  def test_cond(self):
+    def f(x, y):
+      @custom_transpose(jnp.ones(2))
+      def fn(r, x): return x / r
+      @fn.def_transpose
+      def tp(r, t): return 2 * t / r
+
+      return x + fn(y, x)
+
+    def cond_wrap(f):
+      return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
+                                   linear=(True,))
+
+    i = 7.
+    x = jnp.ones(2) * 6.
+    y = jnp.ones(2) * 3.
+
+    f_ = lambda x: f(x, y)
+    f_t = transpose_unary(f_, x)
+    g_ = partial(cond_wrap(f_), i)
+    g_t = transpose_unary(g_, x)
+
+    self.assertAllClose(f_(x), g_(x))
+    self.assertAllClose(f_t(x), g_t(x))
+
+  def test_cond_recursive(self):
+    def f(x, y):
+      @custom_transpose(jnp.ones(2))
+      def fn(r, x): return x / r
+      @fn.def_transpose
+      def tp(r, t): return 2 * fn(r, t)
+
+      return x + fn(y, x)
+
+    def cond_wrap(f):
+      return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
+                                   linear=(True,))
+
+    i = 7.
+    x = jnp.ones(2) * 6.
+    y = jnp.ones(2) * 3.
+
+    f_ = lambda x: f(x, y)
+    f_t = transpose_unary(f_, x)
+    g_ = partial(cond_wrap(f_), i)
+    g_t = transpose_unary(g_, x)
+
+    self.assertAllClose(f_(x), g_(x))
+    self.assertAllClose(f_t(x), g_t(x))
+
 
 class CustomVmapTest(jtu.JaxTestCase):
 
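These tests use a transpose_unary helper defined earlier in api_test.py and not shown in this diff. A sketch of an equivalent helper, assuming only the public jax.linear_transpose API, which maps output cotangents to a tuple of input cotangents:

import jax
import jax.numpy as jnp

def transpose_unary(f, x_example):
  # Transpose the linear unary function `f`; `x_example` supplies the input
  # shape/dtype that `jax.linear_transpose` needs.
  def transposed(y):
    x, = jax.linear_transpose(f, x_example)(y)
    return x
  return transposed

# Example: the transpose of x -> 2 * x is again multiplication by 2.
g_t = transpose_unary(lambda x: 2. * x, jnp.ones(2))
print(g_t(jnp.ones(2)))   # [2. 2.]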