Add xla garbage collection to gc.callback

This commit is contained in:
Jake VanderPlas 2023-03-10 10:49:25 -08:00
parent 04def0b6ab
commit 36d2179a85
2 changed files with 19 additions and 2 deletions

View File

@ -15,10 +15,9 @@
# This module is largely a wrapper around `jaxlib` that performs version
# checking on import.
import platform
import gc
import re
import os
import warnings
from typing import Optional, Tuple
try:
@ -94,6 +93,11 @@ pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
# XLA garbage collection: see https://github.com/google/jax/issues/14882
def _xla_gc_callback(*args):
xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback)
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error

View File

@ -9629,5 +9629,18 @@ class AutodidaxTest(jtu.JaxTestCase):
autodidax_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(autodidax_module)
class GarbageCollectionTest(jtu.JaxTestCase):
def test_xla_gc_callback(self):
# https://github.com/google/jax/issues/14882
x_np = np.arange(10, dtype='int32')
x_jax = jax.device_put(x_np)
x_np_weakref = weakref.ref(x_np)
del x_np
del x_jax
gc.collect()
assert x_np_weakref() is None
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())