diff --git a/jax/_src/api.py b/jax/_src/api.py index c01bd14ca..0a9901b96 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -81,6 +81,7 @@ from jax.interpreters.invertible_ad import custom_ivjp from jax.custom_batching import custom_vmap from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp, custom_vjp, linear_call) +from jax.custom_transpose import custom_transpose from jax.ad_checkpoint import checkpoint_policies from jax._src.config import (flags, config, bool_env, diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py new file mode 100644 index 000000000..316cfde83 --- /dev/null +++ b/jax/_src/custom_transpose.py @@ -0,0 +1,134 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import functools
from typing import Callable, Optional

from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_unflatten,
                           treedef_tuple)
from jax._src import ad_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs


source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)


# Shadow the builtins with the `util.safe_*` variants; the builtins stay
# reachable under the `unsafe_*` names.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


class custom_transpose:
  """Wrap a two-argument function so that it carries its own transpose rule.

  The wrapped function is called as ``f(residual_arg, linear_arg)``. Calling
  the wrapper traces ``fun`` to a jaxpr and binds ``custom_transpose_p`` with
  that jaxpr and with whatever rule has been registered through
  ``def_transpose`` at that moment. When the primitive is transposed, the
  rule is invoked as ``rule(residual_arg, output_cotangent)`` and its result
  is used as the cotangent of ``linear_arg``.
  """
  fun: Callable
  transpose: Optional[Callable]

  def __init__(self, fun: Callable):
    self.fun = fun  # type: ignore[assignment]
    self.transpose = None
    functools.update_wrapper(self, fun)

  def def_transpose(self, transpose: Callable):
    """Register ``transpose`` as the transpose rule and return it unchanged.

    Returning the argument lets this method be used as a decorator.
    """
    self.transpose = transpose
    return transpose

  @traceback_util.api_boundary
  def __call__(self, residual_arg, linear_arg):
    """Trace ``fun`` on the two arguments and bind ``custom_transpose_p``.

    Args:
      residual_arg: pytree passed as the first argument of ``fun`` and, on
        transposition, as the first argument of the transpose rule.
      linear_arg: pytree passed as the second argument of ``fun``; the
        transpose rule must produce a cotangent for it.

    Returns:
      The outputs of ``fun``, with the pytree structure ``fun`` returned.
    """
    # The individual tree structures are recorded as primitive parameters so
    # the transpose rule can rebuild both arguments from the flat operands.
    _, res_tree = tree_flatten(residual_arg)
    _, lin_tree = tree_flatten(linear_arg)
    args_flat, in_tree = tree_flatten((residual_arg, linear_arg))

    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    # The closed jaxpr below is built with an empty constant list, so tracing
    # must not have produced any constants.
    assert not consts
    closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    out_flat = custom_transpose_p.bind(*consts, *args_flat,
                                       call=closed_call,
                                       rule=self.transpose,
                                       lin_tree=lin_tree,
                                       res_tree=res_tree,
                                       out_tree=out_tree())
    return tree_unflatten(out_tree(), out_flat)


### utils

def rule_name(rule):
  """Return a printable name for ``rule`` for use in error messages.

  Callables without a ``__name__`` attribute (for example
  ``functools.partial`` objects) get a descriptive placeholder so the
  message never ends in an empty pair of parentheses.
  """
  return getattr(rule, '__name__', '<unnamed transpose rule>')

def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
  """Check the transpose rule's output structure against the linear inputs.

  The structures must be equal, with one exception: when ``lin_tree`` has
  exactly one child, ``rule_out_tree`` may instead equal that child.

  Raises:
    ValueError: if neither comparison succeeds.
  """
  if lin_tree != rule_out_tree and len(lin_tree.children()) == 1:
    lin_tree2, = lin_tree.children()
  else:
    lin_tree2 = lin_tree
  if lin_tree2 != rule_out_tree:
    raise ValueError(
        "structure of custom transpose rule's output does not match "
        "structure of primal function's linear inputs under "
        f'custom transpose rule ({rule_name(rule)}).\n'
        f'Transpose rule output: {rule_out_tree}\n'
        f'Linear primal inputs: {lin_tree}')


### custom_transpose_p rules


def custom_transpose_impl(*args, call, rule, res_tree, lin_tree, out_tree):
  """Evaluate the primitive by running the traced jaxpr on the flat operands.

  Only ``call`` is used; the remaining parameters matter for transposition.
  """
  del rule, res_tree, lin_tree, out_tree
  return core.jaxpr_as_fun(call)(*args)


def custom_transpose_transpose_rule(
    cts, *args, call, rule, res_tree, lin_tree, out_tree):
  """Transpose ``custom_transpose_p`` by calling the user-supplied rule.

  Args:
    cts: flat output cotangents, aligned with ``call.out_avals``.
    *args: flat operands; they unflatten to ``(res_arg, lin_arg)``.
    call: closed jaxpr of the forward function.
    rule: the registered transpose rule, or ``None`` if none was registered.
    res_tree: tree structure of the residual argument.
    lin_tree: tree structure of the linear argument.
    out_tree: tree structure of the forward function's output.

  Returns:
    A flat list with ``None`` for every residual leaf followed by the
    flattened cotangent returned by ``rule``.

  Raises:
    TypeError: if no transpose rule was registered.
    ValueError: if the rule's output structure does not match ``lin_tree``.
  """
  if rule is None:
    # Without this check the failure is an opaque
    # "'NoneType' object is not callable" raised from the call below.
    raise TypeError(
        'a custom_transpose function was transposed, but no transpose rule '
        'was registered for it; call def_transpose before calling the '
        'function.')

  call_in_tree = treedef_tuple((res_tree, lin_tree))

  res_arg, lin_arg = tree_unflatten(call_in_tree, args)
  assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
  assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

  # Replace symbolic zero cotangents with zeros built from the output avals
  # before handing the cotangents to the user rule.
  cts = [ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
         for ct, ct_aval in zip(cts, call.out_avals)]
  ct_out = tree_unflatten(out_tree, cts)
  ct_lin = rule(res_arg, ct_out)
  ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
  check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
  return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_abstract_eval(*in_avals, call, **_):
  """Return the output avals recorded on the traced forward jaxpr."""
  return call.out_avals


custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl, new_style=True,
                                       multiple_results=True),
                         initial_style=True)
+mlir.register_lowering(custom_transpose_p, mlir.lower_fun( + custom_transpose_impl, multiple_results=True)) diff --git a/jax/custom_transpose.py b/jax/custom_transpose.py new file mode 100644 index 000000000..1e00577c1 --- /dev/null +++ b/jax/custom_transpose.py @@ -0,0 +1,18 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: F401 +from jax._src.custom_transpose import ( + custom_transpose, +) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index f7956cc9e..c8ecde3f6 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -981,6 +981,7 @@ tf_not_yet_impl = [ "after_all", "all_to_all", "create_token", + "custom_transpose_call", "infeed", "linear_call", "outfeed", diff --git a/tests/api_test.py b/tests/api_test.py index 042c9e2f1..d0c5613d3 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6200,13 +6200,14 @@ class CustomVJPTest(jtu.JaxTestCase): modes=['rev']) -class CustomTransposeTest(jtu.JaxTestCase): +def transpose_unary(f, x_example): + def transposed(y): + x, = api.linear_transpose(f, x_example)(y) + return x + return transposed - def transpose(self, f, x_example): - def transposed(y): - x, = api.linear_transpose(f, x_example)(y) - return x - return transposed + +class CustomTransposeTest(jtu.JaxTestCase): def test_linear_call(self): def f(x, y): @@ -6223,8 +6224,8 @@ class CustomTransposeTest(jtu.JaxTestCase): f1 = lambda x: f(x, y) 
f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(self.transpose(f1, x)(x), - self.transpose(f1_ref, x)(x)) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) def test_linear_call_incorrect_transpose(self): def f(x, y): @@ -6241,8 +6242,8 @@ class CustomTransposeTest(jtu.JaxTestCase): f1 = lambda x: f(x, y) f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(self.transpose(f1, x)(x), - self.transpose(f1_ref, x)(x)) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) def test_linear_call_transpose_transpose_transpose(self): def fn(r, x): return x / r @@ -6253,9 +6254,9 @@ class CustomTransposeTest(jtu.JaxTestCase): x = jnp.ones(2) * 6. y = jnp.ones(2) * 3. f = lambda x: f_(x, y) - ft = self.transpose(f, x) - ftt = self.transpose(ft, x) - fttt = self.transpose(ftt, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) self.assertAllClose(ft(x), x + tp(y, x)) self.assertAllClose(f(x), ftt(x)) self.assertAllClose(ft(x), fttt(x)) @@ -6277,8 +6278,8 @@ class CustomTransposeTest(jtu.JaxTestCase): c, x = 2., 3. t = [4., 5.] self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(self.transpose(partial(f, c), x)(t), - self.transpose(partial(f_ref, c), x)(t)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) def test_linear_call_nested(self): # identity function with an untrue transpose of 0 @@ -6296,11 +6297,11 @@ class CustomTransposeTest(jtu.JaxTestCase): return api.linear_call(f_, t_, (), x) x = 5. 
- id_t = self.transpose(id_, x) - id_tt = self.transpose(id_t, x) - ft = self.transpose(f, x) - ftt = self.transpose(ft, x) - fttt = self.transpose(ftt, x) + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) self.assertAllClose(id_(x), x) self.assertAllClose(id_t(x), 0.) @@ -6322,8 +6323,216 @@ class CustomTransposeTest(jtu.JaxTestCase): self.assertAllClose(f(x, y), jax.jit(f)(x, y)) f1 = lambda x: f(x, y) - self.assertAllClose(self.transpose(f1, x)(x), - jax.jit(self.transpose(f1, x))(x)) + self.assertAllClose(transpose_unary(f1, x)(x), + jax.jit(transpose_unary(f1, x))(x)) + + def test_basic(self): + def f(x, y): + @api.custom_transpose + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_incorrect_transpose(self): + def f(x, y): + @api.custom_transpose + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / (2. * r) # nb: not the true transpose + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_transpose_transpose_transpose(self): + @api.custom_transpose + def fn(r, x): return x / r + @api.custom_transpose + def tp(r, t): return t / (2. * r) # nb: untrue transpose + + fn.def_transpose(tp) + tp.def_transpose(fn) + + def f_(x, y): + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_scalar_to_vector(self): + def f(c, x): + @api.custom_transpose + def fn(_, x): + return [x, x] + + @fn.def_transpose + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return fn((), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + f = api.custom_transpose(lambda _, x: x) + t = api.custom_transpose(lambda _, t: 0.) + f.def_transpose(t) + t.def_transpose(f) + return f((), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + f_ = api.custom_transpose(lambda _, x: id_(x)) + t_ = api.custom_transpose(lambda _, t: id_(7.)) + f_.def_transpose(t_) + t_.def_transpose(f_) + return f_((), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_one_degree(self): + T = lambda f: transpose_unary(f, 0.) + + @api.custom_transpose + def f(_, z): return 2. * z + @f.def_transpose + def ft(_, z): return 3. 
* z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(3., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... + + def test_two_degrees(self): + T = lambda f: transpose_unary(f, 0.) + + @api.custom_transpose + def f(_, z): return 2. * z + + @f.def_transpose + @api.custom_transpose + def ft(_, z): return 3. * z + + @ft.def_transpose + def ftt(_, z): return 7. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(7., T(T(f))(1.)) + self.assertAllClose(7., T(T(T(f)))(1.)) + self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... + + def test_symmetric(self): + T = lambda f: transpose_unary(f, 0.) + + @api.custom_transpose + def f(_, z): return 2. * z + @api.custom_transpose + def g(_, z): return 3. * z + + f.def_transpose(g) + g.def_transpose(f) + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(2., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... + + def test_recursive(self): + T = lambda f: transpose_unary(f, 0.) + + @api.custom_transpose + def f(c, z): return c * z + + @f.def_transpose + def ft(c, z): return f(c + 1., z) + + g = partial(f, 1.) + self.assertAllClose(1., g(1.)) + self.assertAllClose(2., T(g)(1.)) + self.assertAllClose(3., T(T(g))(1.)) + self.assertAllClose(4., T(T(T(g)))(1.)) + self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... + + def test_jit(self): + def f(x, y): + @api.custom_transpose + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) class CustomVmapTest(jtu.JaxTestCase):