ahead-of-time lowering and compilation frontend for xmap

This commit is contained in:
Roy Frostig 2021-11-23 18:22:12 -08:00
parent d81114deff
commit e53ec025c7
2 changed files with 33 additions and 5 deletions

View File

@ -26,7 +26,7 @@ from enum import Enum
from jax import numpy as jnp
from jax import core
from jax import linear_util as lu
from jax._src.api import _check_callable, _check_arg
from jax._src.api import Lowered, _check_callable, _check_arg
from jax._src import dispatch
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves)
@ -639,7 +639,7 @@ def xmap(fun: Callable,
spmd_in_axes=None,
spmd_out_axes_thunk=None,
positional_semantics=_positional_semantics)
return fun_flat, args_flat, params, out_tree
return fun_flat, args_flat, params, in_tree, out_tree
def verify_outputs(out_flat, out_tree, params):
if has_output_rank_assertions:
@ -652,7 +652,7 @@ def xmap(fun: Callable,
def fun_mapped(*args):
tree_map(_check_arg, args)
fun_flat, args_flat, params, out_tree = infer_params(*args)
fun_flat, args_flat, params, _, out_tree = infer_params(*args)
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
return verify_outputs(out_flat, out_tree, params)
@ -662,13 +662,15 @@ def xmap(fun: Callable,
return f
def lower(*args):
fun_flat, args_flat, params, out_tree = infer_params(*args)
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
return make_xmap_callable(
computation = make_xmap_callable(
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'], params['positional_semantics'], *avals_flat)
return Lowered(
computation, in_tree, out_tree(), donate_argnums, no_kwargs=True)
fun_mapped = wraps(fun)(decorate_serial(fun_mapped))
fun_mapped.lower = decorate_serial(lower)

View File

@ -591,6 +591,32 @@ class XMapTest(XMapTestCase):
# Make sure this doesn't crash
xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]).lower(x)
def testLowerCompile(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
f_exe = f.lower(x).compile()
self.assertAllClose(f_exe(x), f(x))
def testLowerCompileInTreeMismatch(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
f_exe = f.lower(x).compile()
self.assertRaisesRegex(
TypeError, "function compiled for .*, called with .*",
lambda: f_exe([x]))
def testLowerCompileArgTypeMismatch(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
"Computation compiled for input types:\n.*float32.*\n"
"called with:\n.*int32.*",
lambda: f_exe(x_i32))
class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""