From 45af307a618d38649bc1a6c51ee8d4c9e28f27b7 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 18 Feb 2022 13:44:06 -0800 Subject: [PATCH] staging and compilation for `custom_transpose` Co-authored-by: Matthew Johnson --- jax/_src/custom_transpose.py | 14 +++++++++ jax/interpreters/partial_eval.py | 40 ++++++++++++++++++++++-- tests/api_test.py | 53 +++++++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 6 deletions(-) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 9d60e623a..6abb1beac 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Tuple from jax import core from jax import linear_util as lu from jax.interpreters import ad +from jax.interpreters import mlir +from jax.interpreters import xla from jax.tree_util import (tree_flatten, tree_leaves, tree_map, tree_structure, treedef_tuple, tree_unflatten) from jax._src import ad_util @@ -187,6 +189,18 @@ def custom_transpose_transpose_rule( return [None] * len(tree_leaves(res_arg)) + ct_lin_flat +def custom_transpose_lowering(*args, call_jaxpr, **params): + return core.jaxpr_as_fun(call_jaxpr)(*args) + + custom_transpose_p = CustomTransposePrimitive('custom_transpose_call') core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule +mlir.register_lowering( + custom_transpose_p, + mlir.lower_fun(custom_transpose_lowering, multiple_results=True)) +xla.register_translation( + custom_transpose_p, + xla.lower_fun( + custom_transpose_lowering, new_style=True, multiple_results=True), + initial_style=True) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index c700ecb00..c2835cb82 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -30,8 +30,9 @@ from jax._src import dtypes from jax import linear_util as lu from jax._src import profiler from jax._src.ad_util import Zero -from jax._src.api_util import flattened_fun_in_tree -from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves +from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, + tree_leaves) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, cache, OrderedSet, as_hashable_function) @@ -1610,6 +1611,41 @@ class DynamicJaxprTrace(core.Trace): def post_process_custom_vjp_call(self, out_tracers, _): assert False # unreachable + def process_custom_transpose(self, prim, call, tracers, + transpose, out_types, + lin_tree, res_tree, out_tree): + tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) + + in_avals_p = [t.aval for t in tracers] + in_avals_t = [*[t.aval for t in tracers_res], *out_types] + + with core.new_sublevel(): + call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic( + call, self.main, in_avals_p) + closed_call_jaxpr = core.ClosedJaxpr( + convert_constvars_jaxpr(call_jaxpr), ()) + + transpose_flat, in_tree2 = flatten_fun_nokwargs( + lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) + transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic( + transpose_flat, self.main, in_avals_t) + closed_transpose_jaxpr = core.ClosedJaxpr( + convert_constvars_jaxpr(transpose_jaxpr), ()) + + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + invars = map(self.getvar, tracers) + constvars = map(self.getvar, map(self.instantiate_const, call_consts)) + outvars = map(self.makevar, out_tracers) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, + dict(call_jaxpr=closed_call_jaxpr, + transpose_jaxpr=(closed_transpose_jaxpr, + transpose_consts), + num_consts=len(call_consts)), + source_info_util.current()) + self.frame.eqns.append(eqn) + return out_tracers + + custom_staging_rules: Dict[Primitive, Callable] = {} def _memoize(thunk): diff --git a/tests/api_test.py b/tests/api_test.py index 88d3d0721..99db51a50 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6778,14 +6778,32 @@ class CustomTransposeTest(jtu.JaxTestCase): self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), api.jvp(f_ref, [x, y], [tx, ty])) - def test_jit(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj) - + def test_make_jaxpr(self): def f(x, y): @custom_transpose(jnp.ones(2)) def fn(r, x): return x / r @fn.def_transpose - def tp(r, t): return t / r + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + + jaxpr = api.make_jaxpr(f_)(x) + self.assertIn('custom_transpose_call', str(jaxpr)) + + jaxpr_t = api.make_jaxpr(f_t)(x) + self.assertNotIn('custom_transpose_call', str(jaxpr_t)) + + def test_jit(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r return x + fn(y, x) @@ -6795,8 +6813,35 @@ class CustomTransposeTest(jtu.JaxTestCase): f_ = lambda x: f(x, y) f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) self.assertAllClose(f_(x), jax.jit(f_)(x)) self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_jit_recursive(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj) + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) class CustomVmapTest(jtu.JaxTestCase):