mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Properly count sublevels when tracing xmap body
Otherwise it can lead to tracer leak errors. I'm not a 100% sure how this works out, because the sublevel counting has changed since I read it previously. This replicates the changes applied to DynamicJaxprTrace.process_map since I last looked at it.
This commit is contained in:
parent
5d6f81cda8
commit
7439e1b1f8
@ -1015,8 +1015,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
|
||||
for a, a_in_axes in zip(in_avals, params['in_axes'])]
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, mapped_in_avals)
|
||||
with core.new_sublevel():
|
||||
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, mapped_in_avals)
|
||||
out_axes = params['out_axes_thunk']()
|
||||
if params['spmd_out_axes_thunk'] is not None:
|
||||
spmd_out_axes = params['spmd_out_axes_thunk']()
|
||||
|
@ -365,6 +365,19 @@ class XMapTest(XMapTestCase):
|
||||
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})(x)
|
||||
|
||||
def testNoTracerLeak(self):
|
||||
@jax.jit
|
||||
def xmap_linearize(xs):
|
||||
eye = jnp.eye(xs.shape[0], dtype=jnp.float32)
|
||||
primal, grad_f = jax.linearize(jnp.sin, xs)
|
||||
return maps.xmap(
|
||||
grad_f,
|
||||
in_axes=['i', ...],
|
||||
out_axes=['i', ...],
|
||||
axis_resources={'i': maps.SerialLoop(1)})(eye)
|
||||
xs = jnp.arange(1, 4, step=1).astype(jnp.float32)
|
||||
xmap_linearize(xs) # Doesn't raise a tracer leak error
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
|
Loading…
x
Reference in New Issue
Block a user