introduce custom_transpose and a corresponding primitive

Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
This commit is contained in:
Roy Frostig 2022-01-07 14:33:58 -08:00
parent 285c20388b
commit 1709e06800
5 changed files with 385 additions and 22 deletions

View File

@ -81,6 +81,7 @@ from jax.interpreters.invertible_ad import custom_ivjp
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.ad_checkpoint import checkpoint_policies
from jax._src.config import (flags, config, bool_env,

View File

@ -0,0 +1,134 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Callable, Optional
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_unflatten,
treedef_tuple)
from jax._src import ad_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
class custom_transpose:
fun: Callable
transpose: Optional[Callable]
def __init__(self, fun: Callable):
self.fun = fun # type: ignore[assignment]
self.transpose = None
functools.update_wrapper(self, fun)
def def_transpose(self, transpose: Callable):
self.transpose = transpose
return transpose
@traceback_util.api_boundary
def __call__(self, residual_arg, linear_arg):
res_arg, lin_arg = residual_arg, linear_arg
_, res_tree = tree_flatten(res_arg)
_, lin_tree = tree_flatten(lin_arg)
args_flat, in_tree = tree_flatten((res_arg, lin_arg))
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
out_flat = custom_transpose_p.bind(*consts, *args_flat,
call=closed_call,
rule=self.transpose,
lin_tree=lin_tree,
res_tree=res_tree,
out_tree=out_tree())
return tree_unflatten(out_tree(), out_flat)
### utils
def rule_name(rule):
return getattr(rule, '__name__', '<unnamed transpose rule>')
def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
if lin_tree != rule_out_tree and len(lin_tree.children()) == 1:
lin_tree2, = lin_tree.children()
else:
lin_tree2 = lin_tree
if lin_tree2 != rule_out_tree:
raise ValueError(
'structure of custom transpose rule\'s output does not match '
'structure of primal function\'s linear inputs under '
f'custom transpose rule ({rule_name(rule)}).\n'
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')
### custom_transpose_p rules
def custom_transpose_impl(*args, call, rule, res_tree, lin_tree, out_tree):
del rule, res_tree, lin_tree, out_tree
return core.jaxpr_as_fun(call)(*args)
def custom_transpose_transpose_rule(
cts, *args, call, rule, res_tree, lin_tree, out_tree):
call_in_tree = treedef_tuple((res_tree, lin_tree))
res_arg, lin_arg = tree_unflatten(call_in_tree, args)
assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))
cts = [ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
for ct, ct_aval in zip(cts, call.out_avals)]
ct_out = tree_unflatten(out_tree, cts)
ct_lin = rule(res_arg, ct_out)
ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
def custom_transpose_abstract_eval(*in_avals, call, **_):
return call.out_avals
custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
xla.lower_fun(custom_transpose_impl, new_style=True,
multiple_results=True),
initial_style=True)
mlir.register_lowering(custom_transpose_p, mlir.lower_fun(
custom_transpose_impl, multiple_results=True))

18
jax/custom_transpose.py Normal file
View File

@ -0,0 +1,18 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from jax._src.custom_transpose import (
custom_transpose,
)

View File

@ -981,6 +981,7 @@ tf_not_yet_impl = [
"after_all",
"all_to_all",
"create_token",
"custom_transpose_call",
"infeed",
"linear_call",
"outfeed",

View File

@ -6200,13 +6200,14 @@ class CustomVJPTest(jtu.JaxTestCase):
modes=['rev'])
class CustomTransposeTest(jtu.JaxTestCase):
def transpose_unary(f, x_example):
def transposed(y):
x, = api.linear_transpose(f, x_example)(y)
return x
return transposed
def transpose(self, f, x_example):
def transposed(y):
x, = api.linear_transpose(f, x_example)(y)
return x
return transposed
class CustomTransposeTest(jtu.JaxTestCase):
def test_linear_call(self):
def f(x, y):
@ -6223,8 +6224,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
f1 = lambda x: f(x, y)
f1_ref = lambda x: f_ref(x, y)
self.assertAllClose(self.transpose(f1, x)(x),
self.transpose(f1_ref, x)(x))
self.assertAllClose(transpose_unary(f1, x)(x),
transpose_unary(f1_ref, x)(x))
def test_linear_call_incorrect_transpose(self):
def f(x, y):
@ -6241,8 +6242,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
f1 = lambda x: f(x, y)
f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor
self.assertAllClose(self.transpose(f1, x)(x),
self.transpose(f1_ref, x)(x))
self.assertAllClose(transpose_unary(f1, x)(x),
transpose_unary(f1_ref, x)(x))
def test_linear_call_transpose_transpose_transpose(self):
def fn(r, x): return x / r
@ -6253,9 +6254,9 @@ class CustomTransposeTest(jtu.JaxTestCase):
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
f = lambda x: f_(x, y)
ft = self.transpose(f, x)
ftt = self.transpose(ft, x)
fttt = self.transpose(ftt, x)
ft = transpose_unary(f, x)
ftt = transpose_unary(ft, x)
fttt = transpose_unary(ftt, x)
self.assertAllClose(ft(x), x + tp(y, x))
self.assertAllClose(f(x), ftt(x))
self.assertAllClose(ft(x), fttt(x))
@ -6277,8 +6278,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
c, x = 2., 3.
t = [4., 5.]
self.assertAllClose(f(c, x), f_ref(c, x))
self.assertAllClose(self.transpose(partial(f, c), x)(t),
self.transpose(partial(f_ref, c), x)(t))
self.assertAllClose(transpose_unary(partial(f, c), x)(t),
transpose_unary(partial(f_ref, c), x)(t))
def test_linear_call_nested(self):
# identity function with an untrue transpose of 0
@ -6296,11 +6297,11 @@ class CustomTransposeTest(jtu.JaxTestCase):
return api.linear_call(f_, t_, (), x)
x = 5.
id_t = self.transpose(id_, x)
id_tt = self.transpose(id_t, x)
ft = self.transpose(f, x)
ftt = self.transpose(ft, x)
fttt = self.transpose(ftt, x)
id_t = transpose_unary(id_, x)
id_tt = transpose_unary(id_t, x)
ft = transpose_unary(f, x)
ftt = transpose_unary(ft, x)
fttt = transpose_unary(ftt, x)
self.assertAllClose(id_(x), x)
self.assertAllClose(id_t(x), 0.)
@ -6322,8 +6323,216 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
f1 = lambda x: f(x, y)
self.assertAllClose(self.transpose(f1, x)(x),
jax.jit(self.transpose(f1, x))(x))
self.assertAllClose(transpose_unary(f1, x)(x),
jax.jit(transpose_unary(f1, x))(x))
def test_basic(self):
def f(x, y):
@api.custom_transpose
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return t / r
return x + fn(y, x)
def f_ref(x, y):
return x + x / y
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
self.assertAllClose(f(x, y), f_ref(x, y))
f1 = lambda x: f(x, y)
f1_ref = lambda x: f_ref(x, y)
self.assertAllClose(transpose_unary(f1, x)(x),
transpose_unary(f1_ref, x)(x))
def test_incorrect_transpose(self):
def f(x, y):
@api.custom_transpose
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return t / (2. * r) # nb: not the true transpose
return x + fn(y, x)
def f_ref(x, y):
return x + x / y
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
self.assertAllClose(f(x, y), f_ref(x, y))
f1 = lambda x: f(x, y)
f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor
self.assertAllClose(transpose_unary(f1, x)(x),
transpose_unary(f1_ref, x)(x))
def test_transpose_transpose_transpose(self):
@api.custom_transpose
def fn(r, x): return x / r
@api.custom_transpose
def tp(r, t): return t / (2. * r) # nb: untrue transpose
fn.def_transpose(tp)
tp.def_transpose(fn)
def f_(x, y):
return x + fn(y, x)
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
f = lambda x: f_(x, y)
ft = transpose_unary(f, x)
ftt = transpose_unary(ft, x)
fttt = transpose_unary(ftt, x)
self.assertAllClose(ft(x), x + tp(y, x))
self.assertAllClose(f(x), ftt(x))
self.assertAllClose(ft(x), fttt(x))
def test_scalar_to_vector(self):
def f(c, x):
@api.custom_transpose
def fn(_, x):
return [x, x]
@fn.def_transpose
def tp(_, t):
t1, t2 = t
return t1 + t2
return fn((), c * x)
def f_ref(c, x):
return [c * x, c * x]
c, x = 2., 3.
t = [4., 5.]
self.assertAllClose(f(c, x), f_ref(c, x))
self.assertAllClose(transpose_unary(partial(f, c), x)(t),
transpose_unary(partial(f_ref, c), x)(t))
def test_nested(self):
# identity function with an untrue transpose of 0
def id_(x):
f = api.custom_transpose(lambda _, x: x)
t = api.custom_transpose(lambda _, t: 0.)
f.def_transpose(t)
t.def_transpose(f)
return f((), x)
# identity function with an untrue transpose of 7, and where both
# forward and transpose have custom transpositions that should
# never end up invoked.
def f(x):
f_ = api.custom_transpose(lambda _, x: id_(x))
t_ = api.custom_transpose(lambda _, t: id_(7.))
f_.def_transpose(t_)
t_.def_transpose(f_)
return f_((), x)
x = 5.
id_t = transpose_unary(id_, x)
id_tt = transpose_unary(id_t, x)
ft = transpose_unary(f, x)
ftt = transpose_unary(ft, x)
fttt = transpose_unary(ftt, x)
self.assertAllClose(id_(x), x)
self.assertAllClose(id_t(x), 0.)
self.assertAllClose(id_tt(x), x)
self.assertAllClose(f(x), x)
self.assertAllClose(ft(x), 7.)
self.assertAllClose(ftt(x), x)
self.assertAllClose(fttt(x), 7.)
def test_one_degree(self):
T = lambda f: transpose_unary(f, 0.)
@api.custom_transpose
def f(_, z): return 2. * z
@f.def_transpose
def ft(_, z): return 3. * z
f = partial(f, ())
self.assertAllClose(2., f(1.))
self.assertAllClose(3., T(f)(1.))
self.assertAllClose(3., T(T(f))(1.))
self.assertAllClose(3., T(T(T(f)))(1.))
self.assertAllClose(3., T(T(T(T(f))))(1.)) # ...
def test_two_degrees(self):
T = lambda f: transpose_unary(f, 0.)
@api.custom_transpose
def f(_, z): return 2. * z
@f.def_transpose
@api.custom_transpose
def ft(_, z): return 3. * z
@ft.def_transpose
def ftt(_, z): return 7. * z
f = partial(f, ())
self.assertAllClose(2., f(1.))
self.assertAllClose(3., T(f)(1.))
self.assertAllClose(7., T(T(f))(1.))
self.assertAllClose(7., T(T(T(f)))(1.))
self.assertAllClose(7., T(T(T(T(f))))(1.)) # ...
def test_symmetric(self):
T = lambda f: transpose_unary(f, 0.)
@api.custom_transpose
def f(_, z): return 2. * z
@api.custom_transpose
def g(_, z): return 3. * z
f.def_transpose(g)
g.def_transpose(f)
f = partial(f, ())
self.assertAllClose(2., f(1.))
self.assertAllClose(3., T(f)(1.))
self.assertAllClose(2., T(T(f))(1.))
self.assertAllClose(3., T(T(T(f)))(1.))
self.assertAllClose(2., T(T(T(T(f))))(1.)) # ...
def test_recursive(self):
T = lambda f: transpose_unary(f, 0.)
@api.custom_transpose
def f(c, z): return c * z
@f.def_transpose
def ft(c, z): return f(c + 1., z)
g = partial(f, 1.)
self.assertAllClose(1., g(1.))
self.assertAllClose(2., T(g)(1.))
self.assertAllClose(3., T(T(g))(1.))
self.assertAllClose(4., T(T(T(g)))(1.))
self.assertAllClose(5., T(T(T(T(g))))(1.)) # ...
def test_jit(self):
def f(x, y):
@api.custom_transpose
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return t / r
return x + fn(y, x)
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
self.assertAllClose(f_(x), jax.jit(f_)(x))
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
class CustomVmapTest(jtu.JaxTestCase):