[XLA:Python] Fix bug where garbage collection was not being triggered sufficiently often.

The destructor ManagedPyObjects failed to increment the count of garbage objects, which meant that MaybeCollectGarbage() was not triggered. Since the C++ JIT path calls MaybeCollectGarbage(), repeated calls to a JIT-ted function that accumulate garbage might fail to ever collect.

PiperOrigin-RevId: 394050168
This commit is contained in:
Peter Hawkins 2021-08-31 11:36:40 -07:00 committed by jax authors
parent 5f46206898
commit 0cfa95cb4d

View File

@ -661,6 +661,27 @@ class CPPJitTest(jtu.BufferDonationTestCase):
np.testing.assert_allclose(f_pruned(*args), 3)
self.assertEqual(count[0], 1)
@unittest.skipIf(jax.lib._xla_extension_version <= 36,
"Test requires jaxlib 0.1.71")
def testBuffersAreFreedPromptly(self):
# Regression test for a bug where garbage collection was delayed too long
# for NumPy buffers that are aliased zero-copy by the runtime.
@self.jit
def f(x):
return x + 1
refs = []
x = np.ones((10000,), np.float32)
for step in range(1000):
x = f(x)
refs.append(weakref.ref(x))
x = np.asarray(x)
# We expect most of the input buffers to have been garbage
# collected in parallel with the execution. We can't call
# block_until_ready() here because it would force a garbage collection.
live_refs = len([ref for ref in refs if ref() is not None])
self.assertLessEqual(live_refs, 100)
class PythonJitTest(CPPJitTest):