add axis_env argument to make_jaxpr

fixes #5522
This commit is contained in:
Matthew Johnson 2021-01-26 17:25:22 -08:00
parent 1fd1faa06c
commit ff66c55709
2 changed files with 23 additions and 1 deletions

View File

@ -2000,6 +2000,7 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
def make_jaxpr(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
return_shape: bool = False,
) -> Callable[..., core.ClosedJaxpr]:
"""Creates a function that produces its jaxpr given example args.
@ -2009,6 +2010,12 @@ def make_jaxpr(fun: Callable,
arguments and return value should be arrays, scalars, or standard Python
containers (tuple/list/dict) thereof.
static_argnums: See the :py:func:`jax.jit` docstring.
axis_env: Optional, a sequence of pairs where the first element is an axis
name and the second element is a positive integer representing the size of
the mapped axis with that name. This parameter is useful when lowering
functions that involve parallel communication collectives, and it
specifies the axis name/size environment that would be set up by
applications of :py:func:`jax.pmap`.
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the ``jaxpr``
and the second element is a pytree with the same structure as
@ -2069,8 +2076,14 @@ def make_jaxpr(fun: Callable,
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
in_avals = [raise_to_shaped(core.get_aval(x)) for x in jax_args]
if config.omnistaging_enabled:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
else:
if axis_env:
raise NotImplementedError(
"axis_env argument to make_jaxpr only supported with omnistaging.")
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore

View File

@ -2775,6 +2775,15 @@ class JaxprTest(jtu.JaxTestCase):
api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
self.assertEqual(shape_tree, expected)
def test_make_jaxpr_axis_env(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
return x - lax.psum(x, 'i')
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
self.assertIn('psum', str(jaxpr))
class LazyTest(jtu.JaxTestCase):