fix leak checker internal error

The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).

But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
Matthew Johnson 2022-09-23 11:24:13 -07:00
parent 0c085471c7
commit b6ef90ffdd
3 changed files with 36 additions and 5 deletions

View File

@ -1899,15 +1899,12 @@ class DynamicJaxprTrace(core.Trace):
custom_staging_rules: Dict[Primitive, Callable] = {}
def _memoize(thunk):
if config.jax_check_tracer_leaks:
return thunk
cell = []
saved_state = core.thread_local_state.trace_state.copy()
saved_state = [core.thread_local_state.trace_state.copy()]
def memoized():
if not cell:
prev_state = core.thread_local_state.trace_state
core.thread_local_state.trace_state = saved_state
core.thread_local_state.trace_state = saved_state.pop()
try:
cell.append(thunk())
finally:

View File

@ -3502,6 +3502,14 @@ class APITest(jtu.JaxTestCase):
return t(y)
s(3) # doesn't crash
def test_leak_checker_internal_error(self):
def apply_fn(inp):
fn = jax.checkpoint(lambda x: jax.nn.relu(1.0 * x))
return jax.vjp(fn, inp)
with jax.check_tracer_leaks():
jax.jit(apply_fn)(1.0) # don't crash
def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())

View File

@ -211,6 +211,32 @@ class NNFunctionsTest(jtu.JaxTestCase):
with jax.checking_leaks():
fwd() # doesn't crash
def testCustomJVPLeak2(self):
# https://github.com/google/jax/issues/8171
# The above test uses jax.nn.sigmoid, as in the original #8171, but that
# function no longer actually has a custom_jvp! So we inline the old def.
@jax.custom_jvp
def sigmoid(x):
one = jnp.float32(1)
return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))
@jax.jit
def fwd():
a = jnp.array(1., 'float32')
def f(hx, _):
hx = jax.nn.relu(hx + a)
return hx, None
hx = jnp.array(0., 'float32')
jax.lax.scan(f, hx, None, length=2)
with jax.checking_leaks():
fwd() # doesn't crash
InitializerRecord = collections.namedtuple(
"InitializerRecord",
["name", "initializer", "shapes", "dtypes"])