mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir translation.
This commit is contained in:
parent
285c20388b
commit
1709e06800
@ -81,6 +81,7 @@ from jax.interpreters.invertible_ad import custom_ivjp
|
||||
from jax.custom_batching import custom_vmap
|
||||
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
custom_vjp, linear_call)
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.ad_checkpoint import checkpoint_policies
|
||||
|
||||
from jax._src.config import (flags, config, bool_env,
|
||||
|
134
jax/_src/custom_transpose.py
Normal file
134
jax/_src/custom_transpose.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Callable, Optional
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_leaves, tree_unflatten,
|
||||
treedef_tuple)
|
||||
from jax._src import ad_util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
|
||||
class custom_transpose:
|
||||
fun: Callable
|
||||
transpose: Optional[Callable]
|
||||
|
||||
def __init__(self, fun: Callable):
|
||||
self.fun = fun # type: ignore[assignment]
|
||||
self.transpose = None
|
||||
functools.update_wrapper(self, fun)
|
||||
|
||||
def def_transpose(self, transpose: Callable):
|
||||
self.transpose = transpose
|
||||
return transpose
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def __call__(self, residual_arg, linear_arg):
|
||||
res_arg, lin_arg = residual_arg, linear_arg
|
||||
_, res_tree = tree_flatten(res_arg)
|
||||
_, lin_tree = tree_flatten(lin_arg)
|
||||
args_flat, in_tree = tree_flatten((res_arg, lin_arg))
|
||||
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
out_flat = custom_transpose_p.bind(*consts, *args_flat,
|
||||
call=closed_call,
|
||||
rule=self.transpose,
|
||||
lin_tree=lin_tree,
|
||||
res_tree=res_tree,
|
||||
out_tree=out_tree())
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
|
||||
### utils
|
||||
|
||||
def rule_name(rule):
|
||||
return getattr(rule, '__name__', '<unnamed transpose rule>')
|
||||
|
||||
def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
|
||||
if lin_tree != rule_out_tree and len(lin_tree.children()) == 1:
|
||||
lin_tree2, = lin_tree.children()
|
||||
else:
|
||||
lin_tree2 = lin_tree
|
||||
if lin_tree2 != rule_out_tree:
|
||||
raise ValueError(
|
||||
'structure of custom transpose rule\'s output does not match '
|
||||
'structure of primal function\'s linear inputs under '
|
||||
f'custom transpose rule ({rule_name(rule)}).\n'
|
||||
f'Transpose rule output: {rule_out_tree}\n'
|
||||
f'Linear primal inputs: {lin_tree}')
|
||||
|
||||
|
||||
### custom_transpose_p rules
|
||||
|
||||
|
||||
def custom_transpose_impl(*args, call, rule, res_tree, lin_tree, out_tree):
|
||||
del rule, res_tree, lin_tree, out_tree
|
||||
return core.jaxpr_as_fun(call)(*args)
|
||||
|
||||
|
||||
def custom_transpose_transpose_rule(
|
||||
cts, *args, call, rule, res_tree, lin_tree, out_tree):
|
||||
call_in_tree = treedef_tuple((res_tree, lin_tree))
|
||||
|
||||
res_arg, lin_arg = tree_unflatten(call_in_tree, args)
|
||||
assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
|
||||
assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))
|
||||
|
||||
cts = [ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
|
||||
for ct, ct_aval in zip(cts, call.out_avals)]
|
||||
ct_out = tree_unflatten(out_tree, cts)
|
||||
ct_lin = rule(res_arg, ct_out)
|
||||
ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
|
||||
check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
|
||||
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
|
||||
|
||||
|
||||
def custom_transpose_abstract_eval(*in_avals, call, **_):
|
||||
return call.out_avals
|
||||
|
||||
|
||||
custom_transpose_p = core.Primitive('custom_transpose_call')
|
||||
custom_transpose_p.multiple_results = True
|
||||
custom_transpose_p.def_impl(custom_transpose_impl)
|
||||
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
|
||||
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
|
||||
xla.register_translation(custom_transpose_p,
|
||||
xla.lower_fun(custom_transpose_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
mlir.register_lowering(custom_transpose_p, mlir.lower_fun(
|
||||
custom_transpose_impl, multiple_results=True))
|
18
jax/custom_transpose.py
Normal file
18
jax/custom_transpose.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.custom_transpose import (
|
||||
custom_transpose,
|
||||
)
|
@ -981,6 +981,7 @@ tf_not_yet_impl = [
|
||||
"after_all",
|
||||
"all_to_all",
|
||||
"create_token",
|
||||
"custom_transpose_call",
|
||||
"infeed",
|
||||
"linear_call",
|
||||
"outfeed",
|
||||
|
@ -6200,13 +6200,14 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
modes=['rev'])
|
||||
|
||||
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
|
||||
def transpose_unary(f, x_example):
|
||||
def transposed(y):
|
||||
x, = api.linear_transpose(f, x_example)(y)
|
||||
return x
|
||||
return transposed
|
||||
|
||||
def transpose(self, f, x_example):
|
||||
def transposed(y):
|
||||
x, = api.linear_transpose(f, x_example)(y)
|
||||
return x
|
||||
return transposed
|
||||
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_linear_call(self):
|
||||
def f(x, y):
|
||||
@ -6223,8 +6224,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
f1 = lambda x: f(x, y)
|
||||
f1_ref = lambda x: f_ref(x, y)
|
||||
self.assertAllClose(self.transpose(f1, x)(x),
|
||||
self.transpose(f1_ref, x)(x))
|
||||
self.assertAllClose(transpose_unary(f1, x)(x),
|
||||
transpose_unary(f1_ref, x)(x))
|
||||
|
||||
def test_linear_call_incorrect_transpose(self):
|
||||
def f(x, y):
|
||||
@ -6241,8 +6242,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
f1 = lambda x: f(x, y)
|
||||
f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor
|
||||
self.assertAllClose(self.transpose(f1, x)(x),
|
||||
self.transpose(f1_ref, x)(x))
|
||||
self.assertAllClose(transpose_unary(f1, x)(x),
|
||||
transpose_unary(f1_ref, x)(x))
|
||||
|
||||
def test_linear_call_transpose_transpose_transpose(self):
|
||||
def fn(r, x): return x / r
|
||||
@ -6253,9 +6254,9 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
f = lambda x: f_(x, y)
|
||||
ft = self.transpose(f, x)
|
||||
ftt = self.transpose(ft, x)
|
||||
fttt = self.transpose(ftt, x)
|
||||
ft = transpose_unary(f, x)
|
||||
ftt = transpose_unary(ft, x)
|
||||
fttt = transpose_unary(ftt, x)
|
||||
self.assertAllClose(ft(x), x + tp(y, x))
|
||||
self.assertAllClose(f(x), ftt(x))
|
||||
self.assertAllClose(ft(x), fttt(x))
|
||||
@ -6277,8 +6278,8 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
c, x = 2., 3.
|
||||
t = [4., 5.]
|
||||
self.assertAllClose(f(c, x), f_ref(c, x))
|
||||
self.assertAllClose(self.transpose(partial(f, c), x)(t),
|
||||
self.transpose(partial(f_ref, c), x)(t))
|
||||
self.assertAllClose(transpose_unary(partial(f, c), x)(t),
|
||||
transpose_unary(partial(f_ref, c), x)(t))
|
||||
|
||||
def test_linear_call_nested(self):
|
||||
# identity function with an untrue transpose of 0
|
||||
@ -6296,11 +6297,11 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
return api.linear_call(f_, t_, (), x)
|
||||
|
||||
x = 5.
|
||||
id_t = self.transpose(id_, x)
|
||||
id_tt = self.transpose(id_t, x)
|
||||
ft = self.transpose(f, x)
|
||||
ftt = self.transpose(ft, x)
|
||||
fttt = self.transpose(ftt, x)
|
||||
id_t = transpose_unary(id_, x)
|
||||
id_tt = transpose_unary(id_t, x)
|
||||
ft = transpose_unary(f, x)
|
||||
ftt = transpose_unary(ft, x)
|
||||
fttt = transpose_unary(ftt, x)
|
||||
|
||||
self.assertAllClose(id_(x), x)
|
||||
self.assertAllClose(id_t(x), 0.)
|
||||
@ -6322,8 +6323,216 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
|
||||
|
||||
f1 = lambda x: f(x, y)
|
||||
self.assertAllClose(self.transpose(f1, x)(x),
|
||||
jax.jit(self.transpose(f1, x))(x))
|
||||
self.assertAllClose(transpose_unary(f1, x)(x),
|
||||
jax.jit(transpose_unary(f1, x))(x))
|
||||
|
||||
def test_basic(self):
|
||||
def f(x, y):
|
||||
@api.custom_transpose
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return t / r
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
def f_ref(x, y):
|
||||
return x + x / y
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
self.assertAllClose(f(x, y), f_ref(x, y))
|
||||
|
||||
f1 = lambda x: f(x, y)
|
||||
f1_ref = lambda x: f_ref(x, y)
|
||||
self.assertAllClose(transpose_unary(f1, x)(x),
|
||||
transpose_unary(f1_ref, x)(x))
|
||||
|
||||
def test_incorrect_transpose(self):
|
||||
def f(x, y):
|
||||
@api.custom_transpose
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return t / (2. * r) # nb: not the true transpose
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
def f_ref(x, y):
|
||||
return x + x / y
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
self.assertAllClose(f(x, y), f_ref(x, y))
|
||||
|
||||
f1 = lambda x: f(x, y)
|
||||
f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor
|
||||
self.assertAllClose(transpose_unary(f1, x)(x),
|
||||
transpose_unary(f1_ref, x)(x))
|
||||
|
||||
def test_transpose_transpose_transpose(self):
|
||||
@api.custom_transpose
|
||||
def fn(r, x): return x / r
|
||||
@api.custom_transpose
|
||||
def tp(r, t): return t / (2. * r) # nb: untrue transpose
|
||||
|
||||
fn.def_transpose(tp)
|
||||
tp.def_transpose(fn)
|
||||
|
||||
def f_(x, y):
|
||||
return x + fn(y, x)
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
f = lambda x: f_(x, y)
|
||||
ft = transpose_unary(f, x)
|
||||
ftt = transpose_unary(ft, x)
|
||||
fttt = transpose_unary(ftt, x)
|
||||
self.assertAllClose(ft(x), x + tp(y, x))
|
||||
self.assertAllClose(f(x), ftt(x))
|
||||
self.assertAllClose(ft(x), fttt(x))
|
||||
|
||||
def test_scalar_to_vector(self):
|
||||
def f(c, x):
|
||||
@api.custom_transpose
|
||||
def fn(_, x):
|
||||
return [x, x]
|
||||
|
||||
@fn.def_transpose
|
||||
def tp(_, t):
|
||||
t1, t2 = t
|
||||
return t1 + t2
|
||||
|
||||
return fn((), c * x)
|
||||
|
||||
def f_ref(c, x):
|
||||
return [c * x, c * x]
|
||||
|
||||
c, x = 2., 3.
|
||||
t = [4., 5.]
|
||||
self.assertAllClose(f(c, x), f_ref(c, x))
|
||||
self.assertAllClose(transpose_unary(partial(f, c), x)(t),
|
||||
transpose_unary(partial(f_ref, c), x)(t))
|
||||
|
||||
def test_nested(self):
|
||||
# identity function with an untrue transpose of 0
|
||||
def id_(x):
|
||||
f = api.custom_transpose(lambda _, x: x)
|
||||
t = api.custom_transpose(lambda _, t: 0.)
|
||||
f.def_transpose(t)
|
||||
t.def_transpose(f)
|
||||
return f((), x)
|
||||
|
||||
# identity function with an untrue transpose of 7, and where both
|
||||
# forward and transpose have custom transpositions that should
|
||||
# never end up invoked.
|
||||
def f(x):
|
||||
f_ = api.custom_transpose(lambda _, x: id_(x))
|
||||
t_ = api.custom_transpose(lambda _, t: id_(7.))
|
||||
f_.def_transpose(t_)
|
||||
t_.def_transpose(f_)
|
||||
return f_((), x)
|
||||
|
||||
x = 5.
|
||||
id_t = transpose_unary(id_, x)
|
||||
id_tt = transpose_unary(id_t, x)
|
||||
ft = transpose_unary(f, x)
|
||||
ftt = transpose_unary(ft, x)
|
||||
fttt = transpose_unary(ftt, x)
|
||||
|
||||
self.assertAllClose(id_(x), x)
|
||||
self.assertAllClose(id_t(x), 0.)
|
||||
self.assertAllClose(id_tt(x), x)
|
||||
|
||||
self.assertAllClose(f(x), x)
|
||||
self.assertAllClose(ft(x), 7.)
|
||||
self.assertAllClose(ftt(x), x)
|
||||
self.assertAllClose(fttt(x), 7.)
|
||||
|
||||
def test_one_degree(self):
|
||||
T = lambda f: transpose_unary(f, 0.)
|
||||
|
||||
@api.custom_transpose
|
||||
def f(_, z): return 2. * z
|
||||
@f.def_transpose
|
||||
def ft(_, z): return 3. * z
|
||||
|
||||
f = partial(f, ())
|
||||
self.assertAllClose(2., f(1.))
|
||||
self.assertAllClose(3., T(f)(1.))
|
||||
self.assertAllClose(3., T(T(f))(1.))
|
||||
self.assertAllClose(3., T(T(T(f)))(1.))
|
||||
self.assertAllClose(3., T(T(T(T(f))))(1.)) # ...
|
||||
|
||||
def test_two_degrees(self):
|
||||
T = lambda f: transpose_unary(f, 0.)
|
||||
|
||||
@api.custom_transpose
|
||||
def f(_, z): return 2. * z
|
||||
|
||||
@f.def_transpose
|
||||
@api.custom_transpose
|
||||
def ft(_, z): return 3. * z
|
||||
|
||||
@ft.def_transpose
|
||||
def ftt(_, z): return 7. * z
|
||||
|
||||
f = partial(f, ())
|
||||
self.assertAllClose(2., f(1.))
|
||||
self.assertAllClose(3., T(f)(1.))
|
||||
self.assertAllClose(7., T(T(f))(1.))
|
||||
self.assertAllClose(7., T(T(T(f)))(1.))
|
||||
self.assertAllClose(7., T(T(T(T(f))))(1.)) # ...
|
||||
|
||||
def test_symmetric(self):
|
||||
T = lambda f: transpose_unary(f, 0.)
|
||||
|
||||
@api.custom_transpose
|
||||
def f(_, z): return 2. * z
|
||||
@api.custom_transpose
|
||||
def g(_, z): return 3. * z
|
||||
|
||||
f.def_transpose(g)
|
||||
g.def_transpose(f)
|
||||
|
||||
f = partial(f, ())
|
||||
self.assertAllClose(2., f(1.))
|
||||
self.assertAllClose(3., T(f)(1.))
|
||||
self.assertAllClose(2., T(T(f))(1.))
|
||||
self.assertAllClose(3., T(T(T(f)))(1.))
|
||||
self.assertAllClose(2., T(T(T(T(f))))(1.)) # ...
|
||||
|
||||
def test_recursive(self):
|
||||
T = lambda f: transpose_unary(f, 0.)
|
||||
|
||||
@api.custom_transpose
|
||||
def f(c, z): return c * z
|
||||
|
||||
@f.def_transpose
|
||||
def ft(c, z): return f(c + 1., z)
|
||||
|
||||
g = partial(f, 1.)
|
||||
self.assertAllClose(1., g(1.))
|
||||
self.assertAllClose(2., T(g)(1.))
|
||||
self.assertAllClose(3., T(T(g))(1.))
|
||||
self.assertAllClose(4., T(T(T(g)))(1.))
|
||||
self.assertAllClose(5., T(T(T(T(g))))(1.)) # ...
|
||||
|
||||
def test_jit(self):
|
||||
def f(x, y):
|
||||
@api.custom_transpose
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return t / r
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
|
||||
|
||||
f_ = lambda x: f(x, y)
|
||||
f_t = transpose_unary(f_, x)
|
||||
self.assertAllClose(f_(x), jax.jit(f_)(x))
|
||||
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
|
||||
|
||||
|
||||
class CustomVmapTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user