mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
ahead-of-time lowering and compilation frontend for xmap
This commit is contained in:
parent
d81114deff
commit
e53ec025c7
@ -26,7 +26,7 @@ from enum import Enum
|
||||
from jax import numpy as jnp
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api import _check_callable, _check_arg
|
||||
from jax._src.api import Lowered, _check_callable, _check_arg
|
||||
from jax._src import dispatch
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
|
||||
tree_leaves)
|
||||
@ -639,7 +639,7 @@ def xmap(fun: Callable,
|
||||
spmd_in_axes=None,
|
||||
spmd_out_axes_thunk=None,
|
||||
positional_semantics=_positional_semantics)
|
||||
return fun_flat, args_flat, params, out_tree
|
||||
return fun_flat, args_flat, params, in_tree, out_tree
|
||||
|
||||
def verify_outputs(out_flat, out_tree, params):
|
||||
if has_output_rank_assertions:
|
||||
@ -652,7 +652,7 @@ def xmap(fun: Callable,
|
||||
|
||||
def fun_mapped(*args):
|
||||
tree_map(_check_arg, args)
|
||||
fun_flat, args_flat, params, out_tree = infer_params(*args)
|
||||
fun_flat, args_flat, params, _, out_tree = infer_params(*args)
|
||||
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
|
||||
return verify_outputs(out_flat, out_tree, params)
|
||||
|
||||
@ -662,13 +662,15 @@ def xmap(fun: Callable,
|
||||
return f
|
||||
|
||||
def lower(*args):
|
||||
fun_flat, args_flat, params, out_tree = infer_params(*args)
|
||||
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
|
||||
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
|
||||
return make_xmap_callable(
|
||||
computation = make_xmap_callable(
|
||||
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
|
||||
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
|
||||
params['resource_env'], params['backend'], params['spmd_in_axes'],
|
||||
params['spmd_out_axes_thunk'], params['positional_semantics'], *avals_flat)
|
||||
return Lowered(
|
||||
computation, in_tree, out_tree(), donate_argnums, no_kwargs=True)
|
||||
|
||||
fun_mapped = wraps(fun)(decorate_serial(fun_mapped))
|
||||
fun_mapped.lower = decorate_serial(lower)
|
||||
|
@ -591,6 +591,32 @@ class XMapTest(XMapTestCase):
|
||||
# Make sure this doesn't crash
|
||||
xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]).lower(x)
|
||||
|
||||
def testLowerCompile(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
f_exe = f.lower(x).compile()
|
||||
self.assertAllClose(f_exe(x), f(x))
|
||||
|
||||
def testLowerCompileInTreeMismatch(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
f_exe = f.lower(x).compile()
|
||||
self.assertRaisesRegex(
|
||||
TypeError, "function compiled for .*, called with .*",
|
||||
lambda: f_exe([x]))
|
||||
|
||||
def testLowerCompileArgTypeMismatch(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
x_f32 = x.astype(jnp.float32)
|
||||
x_i32 = x.astype(jnp.int32)
|
||||
f_exe = f.lower(x_f32).compile()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Computation compiled for input types:\n.*float32.*\n"
|
||||
"called with:\n.*int32.*",
|
||||
lambda: f_exe(x_i32))
|
||||
|
||||
|
||||
class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
"""Re-executes all basic tests with the SPMD partitioner enabled"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user