mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add xla garbage collection to gc.callback
This commit is contained in:
parent
04def0b6ab
commit
36d2179a85
@ -15,10 +15,9 @@
|
||||
# This module is largely a wrapper around `jaxlib` that performs version
|
||||
# checking on import.
|
||||
|
||||
import platform
|
||||
import gc
|
||||
import re
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
try:
|
||||
@ -94,6 +93,11 @@ pytree = xla_client._xla.pytree
|
||||
jax_jit = xla_client._xla.jax_jit
|
||||
pmap_lib = xla_client._xla.pmap_lib
|
||||
|
||||
# XLA garbage collection: see https://github.com/google/jax/issues/14882
|
||||
def _xla_gc_callback(*args):
|
||||
xla_client._xla.collect_garbage()
|
||||
gc.callbacks.append(_xla_gc_callback)
|
||||
|
||||
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
|
||||
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
|
||||
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
|
||||
|
@ -9629,5 +9629,18 @@ class AutodidaxTest(jtu.JaxTestCase):
|
||||
autodidax_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(autodidax_module)
|
||||
|
||||
class GarbageCollectionTest(jtu.JaxTestCase):
|
||||
def test_xla_gc_callback(self):
|
||||
# https://github.com/google/jax/issues/14882
|
||||
x_np = np.arange(10, dtype='int32')
|
||||
x_jax = jax.device_put(x_np)
|
||||
x_np_weakref = weakref.ref(x_np)
|
||||
|
||||
del x_np
|
||||
del x_jax
|
||||
gc.collect()
|
||||
|
||||
assert x_np_weakref() is None
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user