mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
staging and compilation for custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
5354a016e6
commit
45af307a61
@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Tuple
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
|
||||
tree_structure, treedef_tuple, tree_unflatten)
|
||||
from jax._src import ad_util
|
||||
@ -187,6 +189,18 @@ def custom_transpose_transpose_rule(
|
||||
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
|
||||
|
||||
|
||||
def custom_transpose_lowering(*args, call_jaxpr, **params):
|
||||
return core.jaxpr_as_fun(call_jaxpr)(*args)
|
||||
|
||||
|
||||
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
|
||||
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
|
||||
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
|
||||
mlir.register_lowering(
|
||||
custom_transpose_p,
|
||||
mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
|
||||
xla.register_translation(
|
||||
custom_transpose_p,
|
||||
xla.lower_fun(
|
||||
custom_transpose_lowering, new_style=True, multiple_results=True),
|
||||
initial_style=True)
|
||||
|
@ -30,8 +30,9 @@ from jax._src import dtypes
|
||||
from jax import linear_util as lu
|
||||
from jax._src import profiler
|
||||
from jax._src.ad_util import Zero
|
||||
from jax._src.api_util import flattened_fun_in_tree
|
||||
from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
|
||||
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
|
||||
tree_leaves)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, cache, OrderedSet,
|
||||
as_hashable_function)
|
||||
@ -1610,6 +1611,41 @@ class DynamicJaxprTrace(core.Trace):
|
||||
def post_process_custom_vjp_call(self, out_tracers, _):
|
||||
assert False # unreachable
|
||||
|
||||
def process_custom_transpose(self, prim, call, tracers,
|
||||
transpose, out_types,
|
||||
lin_tree, res_tree, out_tree):
|
||||
tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves])
|
||||
|
||||
in_avals_p = [t.aval for t in tracers]
|
||||
in_avals_t = [*[t.aval for t in tracers_res], *out_types]
|
||||
|
||||
with core.new_sublevel():
|
||||
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
|
||||
call, self.main, in_avals_p)
|
||||
closed_call_jaxpr = core.ClosedJaxpr(
|
||||
convert_constvars_jaxpr(call_jaxpr), ())
|
||||
|
||||
transpose_flat, in_tree2 = flatten_fun_nokwargs(
|
||||
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
|
||||
transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
|
||||
transpose_flat, self.main, in_avals_t)
|
||||
closed_transpose_jaxpr = core.ClosedJaxpr(
|
||||
convert_constvars_jaxpr(transpose_jaxpr), ())
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
constvars = map(self.getvar, map(self.instantiate_const, call_consts))
|
||||
outvars = map(self.makevar, out_tracers)
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
|
||||
dict(call_jaxpr=closed_call_jaxpr,
|
||||
transpose_jaxpr=(closed_transpose_jaxpr,
|
||||
transpose_consts),
|
||||
num_consts=len(call_consts)),
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
|
||||
|
||||
custom_staging_rules: Dict[Primitive, Callable] = {}
|
||||
|
||||
def _memoize(thunk):
|
||||
|
@ -6778,14 +6778,32 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(api.jvp(f, [x, y], [tx, ty]),
|
||||
api.jvp(f_ref, [x, y], [tx, ty]))
|
||||
|
||||
def test_jit(self):
|
||||
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
|
||||
|
||||
def test_make_jaxpr(self):
|
||||
def f(x, y):
|
||||
@custom_transpose(jnp.ones(2))
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return t / r
|
||||
def tp(r, t): return 2 * t / r
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
f_ = lambda x: f(x, y)
|
||||
f_t = transpose_unary(f_, x)
|
||||
|
||||
jaxpr = api.make_jaxpr(f_)(x)
|
||||
self.assertIn('custom_transpose_call', str(jaxpr))
|
||||
|
||||
jaxpr_t = api.make_jaxpr(f_t)(x)
|
||||
self.assertNotIn('custom_transpose_call', str(jaxpr_t))
|
||||
|
||||
def test_jit(self):
|
||||
def f(x, y):
|
||||
@custom_transpose(jnp.ones(2))
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return 2 * t / r
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
@ -6795,8 +6813,35 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
f_ = lambda x: f(x, y)
|
||||
f_t = transpose_unary(f_, x)
|
||||
g_ = jax.jit(f_)
|
||||
g_t = transpose_unary(g_, x)
|
||||
self.assertAllClose(f_(x), jax.jit(f_)(x))
|
||||
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
|
||||
self.assertAllClose(f_(x), g_(x))
|
||||
self.assertAllClose(f_t(x), g_t(x))
|
||||
|
||||
def test_jit_recursive(self):
|
||||
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
|
||||
def f(x, y):
|
||||
@custom_transpose(jnp.ones(2))
|
||||
def fn(r, x): return x / r
|
||||
@fn.def_transpose
|
||||
def tp(r, t): return 2 * fn(r, t)
|
||||
|
||||
return x + fn(y, x)
|
||||
|
||||
x = jnp.ones(2) * 6.
|
||||
y = jnp.ones(2) * 3.
|
||||
self.assertAllClose(f(x, y), jax.jit(f)(x, y))
|
||||
|
||||
f_ = lambda x: f(x, y)
|
||||
f_t = transpose_unary(f_, x)
|
||||
g_ = jax.jit(f_)
|
||||
g_t = transpose_unary(g_, x)
|
||||
self.assertAllClose(f_(x), jax.jit(f_)(x))
|
||||
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
|
||||
self.assertAllClose(f_(x), g_(x))
|
||||
self.assertAllClose(f_t(x), g_t(x))
|
||||
|
||||
|
||||
class CustomVmapTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user