mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made into an identity function by enabling config.jax_check_tracer_leaks (from references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger the leak checker (which would see if any references to the main trace persisted after finishing tracing of the user function). But after #7345, the leak checker should only trigger when actual Tracers are leaked. So disabling the memoization when jax_check_tracer_leaks is no longer active shouldn't be necessary. (These PR numbers seem out of order! We're not sure why.) Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
parent
0c085471c7
commit
b6ef90ffdd
@ -1899,15 +1899,12 @@ class DynamicJaxprTrace(core.Trace):
|
||||
custom_staging_rules: Dict[Primitive, Callable] = {}
|
||||
|
||||
def _memoize(thunk):
|
||||
if config.jax_check_tracer_leaks:
|
||||
return thunk
|
||||
|
||||
cell = []
|
||||
saved_state = core.thread_local_state.trace_state.copy()
|
||||
saved_state = [core.thread_local_state.trace_state.copy()]
|
||||
def memoized():
|
||||
if not cell:
|
||||
prev_state = core.thread_local_state.trace_state
|
||||
core.thread_local_state.trace_state = saved_state
|
||||
core.thread_local_state.trace_state = saved_state.pop()
|
||||
try:
|
||||
cell.append(thunk())
|
||||
finally:
|
||||
|
@ -3502,6 +3502,14 @@ class APITest(jtu.JaxTestCase):
|
||||
return t(y)
|
||||
s(3) # doesn't crash
|
||||
|
||||
def test_leak_checker_internal_error(self):
|
||||
def apply_fn(inp):
|
||||
fn = jax.checkpoint(lambda x: jax.nn.relu(1.0 * x))
|
||||
return jax.vjp(fn, inp)
|
||||
|
||||
with jax.check_tracer_leaks():
|
||||
jax.jit(apply_fn)(1.0) # don't crash
|
||||
|
||||
def test_default_backend(self):
|
||||
first_local_device = api.local_devices()[0]
|
||||
self.assertEqual(first_local_device.platform, api.default_backend())
|
||||
|
@ -211,6 +211,32 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
with jax.checking_leaks():
|
||||
fwd() # doesn't crash
|
||||
|
||||
def testCustomJVPLeak2(self):
|
||||
# https://github.com/google/jax/issues/8171
|
||||
# The above test uses jax.nn.sigmoid, as in the original #8171, but that
|
||||
# function no longer actually has a custom_jvp! So we inline the old def.
|
||||
|
||||
@jax.custom_jvp
|
||||
def sigmoid(x):
|
||||
one = jnp.float32(1)
|
||||
return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
|
||||
sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))
|
||||
|
||||
@jax.jit
|
||||
def fwd():
|
||||
a = jnp.array(1., 'float32')
|
||||
|
||||
def f(hx, _):
|
||||
hx = jax.nn.relu(hx + a)
|
||||
return hx, None
|
||||
|
||||
hx = jnp.array(0., 'float32')
|
||||
jax.lax.scan(f, hx, None, length=2)
|
||||
|
||||
with jax.checking_leaks():
|
||||
fwd() # doesn't crash
|
||||
|
||||
|
||||
InitializerRecord = collections.namedtuple(
|
||||
"InitializerRecord",
|
||||
["name", "initializer", "shapes", "dtypes"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user