Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Keep axis_env initialized during jaxpr_subcomp
``jaxpr_subcomp`` lowers control-flow primitives by tracing them again as JAX callables, but those are all axis primitives now, so they require a properly initialized axis env.
This commit is contained in:
parent 16f0d51c85
commit 08685efb22
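For context, here is a minimal sketch of the kind of program this change makes lowerable. It mirrors the test added below and assumes a JAX version in which ``jax.xla_computation`` is still available: a function whose ``scan`` body uses the axis primitive ``axis_index``, lowered with an explicit ``axis_env``.

import jax
import jax.numpy as jnp

def fn(x):
  # axis_index is an axis primitive: it is only meaningful while the 'i'
  # axis is present in the axis env.
  z = x * jax.lax.axis_index('i').astype(jnp.float32)
  def inner_fn(carry, a):
    return carry + a, ()
  # scan is a control-flow primitive; jaxpr_subcomp re-traces its body as a
  # JAX callable, so the 'i' axis must still be set up at lowering time.
  return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)

# Build the XLA computation with 'i' declared in the axis env.
c = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(jnp.ones((5, 6, 4)))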
@@ -696,21 +696,21 @@ def xla_computation(fun: Callable,
       for axis_name, size in axis_env or []:
         stack.enter_context(core.extend_axis_env(axis_name, size, None))
       jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
-    jaxpr = xla.apply_outfeed_rewriter(jaxpr)
-    axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
-    if out_parts is None:
-      out_parts_flat = None
-    else:
-      out_parts_flat = tuple(flatten_axes(
-          "xla_computation out_parts", out_tree(), out_parts))
-    c = xb.make_computation_builder(f"xla_computation_{fun_name}")
-    xla_consts = map(partial(xb.constant, c), consts)
-    should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
-    xla_args, donated_invars = xla._xla_callable_args(
-        c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
-    out_nodes = xla.jaxpr_subcomp(
-        c, jaxpr, backend, axis_env_, xla_consts,
-        extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
+      jaxpr = xla.apply_outfeed_rewriter(jaxpr)
+      axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
+      if out_parts is None:
+        out_parts_flat = None
+      else:
+        out_parts_flat = tuple(flatten_axes(
+            "xla_computation out_parts", out_tree(), out_parts))
+      c = xb.make_computation_builder(f"xla_computation_{fun_name}")
+      xla_consts = map(partial(xb.constant, c), consts)
+      should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
+      xla_args, donated_invars = xla._xla_callable_args(
+          c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
+      out_nodes = xla.jaxpr_subcomp(
+          c, jaxpr, backend, axis_env_, xla_consts,
+          extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
     build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
     if out_parts is not None:
       out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)
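The re-indentation in the hunk above is the standard ``contextlib.ExitStack`` pattern: contexts entered through ``stack.enter_context`` stay active only until the ``with`` block exits, so every step that needs the extended axis env has to stay inside that block. A minimal, self-contained sketch of the pattern (``extend_axis_env_stub`` and ``build`` are placeholder names, not JAX APIs):

from contextlib import ExitStack, contextmanager

@contextmanager
def extend_axis_env_stub(name, size):
  # Stand-in for an environment extension that is undone on exit.
  print(f"axis {name}={size} set")
  try:
    yield
  finally:
    print(f"axis {name} cleared")

def build(axis_env):
  with ExitStack() as stack:
    for name, size in axis_env or []:
      stack.enter_context(extend_axis_env_stub(name, size))
    print("tracing")   # was already inside the block
    print("lowering")  # the fix keeps this step inside the block as well
  # leaving the with block tears down every context entered on the stack

build([('i', 8)])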
@@ -1860,6 +1860,16 @@ class APITest(jtu.JaxTestCase):
     axis_env = [(axis_name, api.local_device_count())]
     _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
 
+  def test_xla_computation_axis_env(self):
+    def fn(x):
+      z = x * jax.lax.axis_index('i').astype(jnp.float32)
+      def inner_fn(carry, a):
+        return carry + a, ()
+      return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
+
+    x = jnp.ones((5, 6, 4))
+    _ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
+
   def test_concurrent_device_get_and_put(self):
     def f(x):
       for _ in range(100):