Properly count sublevels when tracing xmap body

Otherwise it can lead to tracer leak errors. I'm not a 100% sure how
this works out, because the sublevel counting has changed since I read
it previously. This replicates the changes applied to
DynamicJaxprTrace.process_map since I last looked at it.
This commit is contained in:
Adam Paszke 2022-07-05 11:15:39 +00:00
parent 5d6f81cda8
commit 7439e1b1f8
2 changed files with 16 additions and 2 deletions

View File

@ -1015,8 +1015,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
for a, a_in_axes in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd(global_axis_sizes.items()):
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, mapped_in_avals)
with core.new_sublevel():
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, mapped_in_avals)
out_axes = params['out_axes_thunk']()
if params['spmd_out_axes_thunk'] is not None:
spmd_out_axes = params['spmd_out_axes_thunk']()

View File

@ -365,6 +365,19 @@ class XMapTest(XMapTestCase):
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
def testNoTracerLeak(self):
@jax.jit
def xmap_linearize(xs):
eye = jnp.eye(xs.shape[0], dtype=jnp.float32)
primal, grad_f = jax.linearize(jnp.sin, xs)
return maps.xmap(
grad_f,
in_axes=['i', ...],
out_axes=['i', ...],
axis_resources={'i': maps.SerialLoop(1)})(eye)
xs = jnp.arange(1, 4, step=1).astype(jnp.float32)
xmap_linearize(xs) # Doesn't raise a tracer leak error
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (