staging and compilation for custom_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Roy Frostig 2022-02-18 13:44:06 -08:00
parent 5354a016e6
commit 45af307a61
3 changed files with 101 additions and 6 deletions

View File

@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Tuple
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
@ -187,6 +189,18 @@ def custom_transpose_transpose_rule(
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
def custom_transpose_lowering(*args, call_jaxpr, **params):
return core.jaxpr_as_fun(call_jaxpr)(*args)
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
custom_transpose_p,
mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_translation(
custom_transpose_p,
xla.lower_fun(
custom_transpose_lowering, new_style=True, multiple_results=True),
initial_style=True)

View File

@ -30,8 +30,9 @@ from jax._src import dtypes
from jax import linear_util as lu
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree
from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, cache, OrderedSet,
as_hashable_function)
@ -1610,6 +1611,41 @@ class DynamicJaxprTrace(core.Trace):
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable
def process_custom_transpose(self, prim, call, tracers,
transpose, out_types,
lin_tree, res_tree, out_tree):
tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves])
in_avals_p = [t.aval for t in tracers]
in_avals_t = [*[t.aval for t in tracers_res], *out_types]
with core.new_sublevel():
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
call, self.main, in_avals_p)
closed_call_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(call_jaxpr), ())
transpose_flat, in_tree2 = flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
transpose_flat, self.main, in_avals_t)
closed_transpose_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(transpose_jaxpr), ())
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, call_consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_call_jaxpr,
transpose_jaxpr=(closed_transpose_jaxpr,
transpose_consts),
num_consts=len(call_consts)),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
custom_staging_rules: Dict[Primitive, Callable] = {}
def _memoize(thunk):

View File

@ -6778,14 +6778,32 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(api.jvp(f, [x, y], [tx, ty]),
api.jvp(f_ref, [x, y], [tx, ty]))
def test_jit(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def test_make_jaxpr(self):
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return t / r
def tp(r, t): return 2 * t / r
return x + fn(y, x)
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
jaxpr = api.make_jaxpr(f_)(x)
self.assertIn('custom_transpose_call', str(jaxpr))
jaxpr_t = api.make_jaxpr(f_t)(x)
self.assertNotIn('custom_transpose_call', str(jaxpr_t))
def test_jit(self):
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * t / r
return x + fn(y, x)
@ -6795,8 +6813,35 @@ class CustomTransposeTest(jtu.JaxTestCase):
f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = jax.jit(f_)
g_t = transpose_unary(g_, x)
self.assertAllClose(f_(x), jax.jit(f_)(x))
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))
def test_jit_recursive(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * fn(r, t)
return x + fn(y, x)
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = jax.jit(f_)
g_t = transpose_unary(g_, x)
self.assertAllClose(f_(x), jax.jit(f_)(x))
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))
class CustomVmapTest(jtu.JaxTestCase):