mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
1fd1faa06c
commit
ff66c55709
15
jax/api.py
15
jax/api.py
@ -2000,6 +2000,7 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
|
||||
def make_jaxpr(fun: Callable,
|
||||
static_argnums: Union[int, Iterable[int]] = (),
|
||||
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
|
||||
return_shape: bool = False,
|
||||
) -> Callable[..., core.ClosedJaxpr]:
|
||||
"""Creates a function that produces its jaxpr given example args.
|
||||
@ -2009,6 +2010,12 @@ def make_jaxpr(fun: Callable,
|
||||
arguments and return value should be arrays, scalars, or standard Python
|
||||
containers (tuple/list/dict) thereof.
|
||||
static_argnums: See the :py:func:`jax.jit` docstring.
|
||||
axis_env: Optional, a sequence of pairs where the first element is an axis
|
||||
name and the second element is a positive integer representing the size of
|
||||
the mapped axis with that name. This parameter is useful when lowering
|
||||
functions that involve parallel communication collectives, and it
|
||||
specifies the axis name/size environment that would be set up by
|
||||
applications of :py:func:`jax.pmap`.
|
||||
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
|
||||
wrapped function returns a pair where the first element is the ``jaxpr``
|
||||
and the second element is a pytree with the same structure as
|
||||
@ -2069,8 +2076,14 @@ def make_jaxpr(fun: Callable,
|
||||
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
|
||||
in_avals = [raise_to_shaped(core.get_aval(x)) for x in jax_args]
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
||||
else:
|
||||
if axis_env:
|
||||
raise NotImplementedError(
|
||||
"axis_env argument to make_jaxpr only supported with omnistaging.")
|
||||
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
|
@ -2775,6 +2775,15 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
|
||||
self.assertEqual(shape_tree, expected)
|
||||
|
||||
def test_make_jaxpr_axis_env(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
return x - lax.psum(x, 'i')
|
||||
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
|
||||
self.assertIn('psum', str(jaxpr))
|
||||
|
||||
|
||||
class LazyTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user