Keep axis_env initialized during jaxpr_subcomp

``jaxpr_subcomp`` lowers control-flow primitives by tracing them again as
JAX callables, but those primitives are all axis primitives now, so tracing
them requires a properly initialized axis env.
Adam Paszke 2021-10-01 11:12:14 +00:00
parent 16f0d51c85
commit 08685efb22
2 changed files with 25 additions and 15 deletions
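
For context, the path this commit fixes is exercised by lowering, via
``xla_computation`` with an explicit ``axis_env``, a function that both reads
a named axis and contains control flow. A minimal sketch, mirroring the
regression test added below (the shapes and the axis size 8 are taken from
that test):

    import jax
    import jax.numpy as jnp

    def fn(x):
      # The traced program refers to the named axis 'i'...
      z = x * jax.lax.axis_index('i').astype(jnp.float32)
      def inner_fn(carry, a):
        return carry + a, ()
      # ...and lowering the scan traces inner_fn again as a JAX callable,
      # which requires 'i' to still be present in the axis env.
      return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)

    x = jnp.ones((5, 6, 4))
    _ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)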

@@ -696,21 +696,21 @@ def xla_computation(fun: Callable,
       for axis_name, size in axis_env or []:
         stack.enter_context(core.extend_axis_env(axis_name, size, None))
       jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
-    jaxpr = xla.apply_outfeed_rewriter(jaxpr)
-    axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
-    if out_parts is None:
-      out_parts_flat = None
-    else:
-      out_parts_flat = tuple(flatten_axes(
-          "xla_computation out_parts", out_tree(), out_parts))
-    c = xb.make_computation_builder(f"xla_computation_{fun_name}")
-    xla_consts = map(partial(xb.constant, c), consts)
-    should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
-    xla_args, donated_invars = xla._xla_callable_args(
-        c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
-    out_nodes = xla.jaxpr_subcomp(
-        c, jaxpr, backend, axis_env_, xla_consts,
-        extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
+      jaxpr = xla.apply_outfeed_rewriter(jaxpr)
+      axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
+      if out_parts is None:
+        out_parts_flat = None
+      else:
+        out_parts_flat = tuple(flatten_axes(
+            "xla_computation out_parts", out_tree(), out_parts))
+      c = xb.make_computation_builder(f"xla_computation_{fun_name}")
+      xla_consts = map(partial(xb.constant, c), consts)
+      should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
+      xla_args, donated_invars = xla._xla_callable_args(
+          c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
+      out_nodes = xla.jaxpr_subcomp(
+          c, jaxpr, backend, axis_env_, xla_consts,
+          extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
     build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
     if out_parts is not None:
       out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)

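The fix itself is a scoping change: the lowering steps move inside the
pre-existing ``with ExitStack() as stack:`` block, so the
``core.extend_axis_env`` context managers entered on the stack remain active
through ``xla.jaxpr_subcomp`` instead of being torn down right after tracing.
A runnable toy sketch of that ``ExitStack`` pattern; the ``extend_axis_env``
here is a stand-in for illustration, not JAX's real one:

    from contextlib import ExitStack, contextmanager

    # Toy stand-in for core.extend_axis_env: registers a named axis in a
    # global table for the dynamic extent of the context manager.
    _axis_sizes = {}

    @contextmanager
    def extend_axis_env(name, size):
      _axis_sizes[name] = size
      try:
        yield
      finally:
        del _axis_sizes[name]

    axis_env = [('i', 8), ('j', 4)]

    with ExitStack() as stack:
      # ExitStack lets us enter a variable number of context managers and
      # keeps them all open until the end of this block.
      for name, size in axis_env:
        stack.enter_context(extend_axis_env(name, size))
      assert _axis_sizes == {'i': 8, 'j': 4}  # tracing *and* lowering see the axes
    assert _axis_sizes == {}  # on exit, the axes are unregistered again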

@@ -1860,6 +1860,16 @@ class APITest(jtu.JaxTestCase):
     axis_env = [(axis_name, api.local_device_count())]
     _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
 
+  def test_xla_computation_axis_env(self):
+    def fn(x):
+      z = x * jax.lax.axis_index('i').astype(jnp.float32)
+      def inner_fn(carry, a):
+        return carry + a, ()
+      return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
+
+    x = jnp.ones((5, 6, 4))
+    _ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
+
   def test_concurrent_device_get_and_put(self):
     def f(x):
       for _ in range(100):