Add config option to log or fatal when jax.Arrays are GCed.

Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
This commit is contained in:
Ionel Gog 2024-10-17 12:22:39 -07:00 committed by jax authors
parent 1b5cf5a494
commit ec279f9c54
5 changed files with 169 additions and 23 deletions

View File

@ -25,8 +25,8 @@ import threading
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
from jax._src import lib
from jax._src.lib import guard_lib
from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client
from jax._src import logging_config
@ -1596,7 +1596,7 @@ jax_xla_profile_version = int_state(
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
state = transfer_guard_lib.thread_local_state()
state = guard_lib.thread_local_state()
prev = state.explicit_device_put
state.explicit_device_put = True
try:
@ -1607,7 +1607,7 @@ def explicit_device_put_scope() -> Iterator[None]:
@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_get() call."""
state = transfer_guard_lib.thread_local_state()
state = guard_lib.thread_local_state()
prev = state.explicit_device_get
state.explicit_device_get = True
try:
@ -1616,19 +1616,19 @@ def explicit_device_get_scope() -> Iterator[None]:
state.explicit_device_get = prev
def _update_transfer_guard(state, key, val):
"""Applies the transfer guard level within transfer_guard_lib."""
"""Applies the transfer guard level within guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
setattr(state, key, guard_lib.TransferGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
setattr(state, key, guard_lib.TransferGuardLevel.LOG)
elif val == 'disallow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW)
elif val == 'log_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
setattr(state, key, guard_lib.TransferGuardLevel.LOG_EXPLICIT)
elif val == 'disallow_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
else:
assert False, f'Invalid transfer guard level {val}'
@ -1637,45 +1637,46 @@ transfer_guard_host_to_device = optional_enum_state(
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for host-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'host_to_device', val),
guard_lib.global_state(), 'host_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
guard_lib.thread_local_state(), 'host_to_device', val))
transfer_guard_device_to_device = optional_enum_state(
name='jax_transfer_guard_device_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_device', val),
guard_lib.global_state(), 'device_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
guard_lib.thread_local_state(), 'device_to_device', val))
transfer_guard_device_to_host = optional_enum_state(
name='jax_transfer_guard_device_to_host',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# The default is applied by guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-host transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_host', val),
guard_lib.global_state(), 'device_to_host', val
),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
guard_lib.thread_local_state(), 'device_to_host', val))
def _update_all_transfer_guard_global(val):
for name in ('jax_transfer_guard_host_to_device',
@ -1688,8 +1689,8 @@ _transfer_guard = optional_enum_state(
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard_*.
default=None,
help=('Select the transfer guard level for all transfers. This option is '
'set-only; the transfer guard level for a specific direction should '
@ -1718,6 +1719,52 @@ def transfer_guard(new_val: str) -> Iterator[None]:
yield
if lib.xla_extension_version < 293:
def array_garbage_collection_guard(_val):
raise NotImplementedError(
'jaxlib version is too low for garbage collection guard'
)
else:
def _update_garbage_collection_guard(state, key, val):
"""Applies the transfer guard level within guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
elif val == 'fatal':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
else:
assert False, f'Invalid garbage collection guard level {val}'
array_garbage_collection_guard = optional_enum_state(
name='jax_array_garbage_collection_guard',
enum_values=['allow', 'log', 'fatal'],
# The default is applied by guard_lib.
default=None,
help=(
'Select garbage collection guard level for "jax.Array" objects.\nThis'
' option can be used to control what happens when a "jax.Array"'
' object is garbage collected. It is desirable for "jax.Array"'
' objects to be freed by Python reference couting rather than garbage'
' collection in order to avoid device memory being held by the arrays'
' until garbage collection occurs.\n\nValid values are:\n * "allow":'
' do not log garbage collection of "jax.Array" objects.\n * "log":'
' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
' fatal error if a "jax.Array" is garbage collected.\nDefault is'
' "allow".'
),
update_global_hook=lambda val: _update_garbage_collection_guard(
guard_lib.global_state(), 'garbage_collect_array', val
),
update_thread_local_hook=lambda val: _update_garbage_collection_guard(
guard_lib.thread_local_state(), 'garbage_collect_array', val
),
)
def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:

View File

@ -155,6 +155,9 @@ def _cuda_path() -> str | None:
cuda_path = _cuda_path()
transfer_guard_lib = xla_client._xla.transfer_guard_lib
if version >= (0, 4, 35):
guard_lib = xla_client._xla.guard_lib
else:
guard_lib = xla_client._xla.transfer_guard_lib
Device = xla_client._xla.Device

View File

@ -100,6 +100,7 @@ def patch_copy_mlir_import(src_file, dst_dir):
_XLA_EXTENSION_STUBS = [
"__init__.pyi",
"guard_lib.pyi",
"ifrt_programs.pyi",
"ifrt_proxy.pyi",
"jax_jit.pyi",

View File

@ -1203,6 +1203,11 @@ jax_multiplatform_test(
srcs = ["transfer_guard_test.py"],
)
jax_multiplatform_test(
name = "garbage_collection_guard_test",
srcs = ["garbage_collection_guard_test.py"],
)
jax_multiplatform_test(
name = "name_stack_test",
srcs = ["name_stack_test.py"],

View File

@ -0,0 +1,90 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for garbage allocation guard."""
import gc
import io
from unittest import mock
from absl.testing import absltest
import jax
from jax._src import config
from jax._src.lib import xla_extension_version
import jax._src.test_util as jtu
import jax.numpy as jnp
jax.config.parse_flags_with_absl()
# Helper class used to create a reference cycle.
class GarbageCollectionGuardTestNodeHelper:
def __init__(self, data):
self.data = data
self.next = None
def _create_array_cycle():
"""Creates a reference cycle of two jax.Arrays."""
n1 = GarbageCollectionGuardTestNodeHelper(jnp.ones((2, 2)))
n2 = GarbageCollectionGuardTestNodeHelper(jnp.zeros((2, 2)))
n1.next = n2
n2.next = n1
class GarbageCollectionGuardTest(jtu.JaxTestCase):
def test_gced_array_is_not_logged_by_default(self):
if xla_extension_version < 293:
self.skipTest("Requires xla_extension_version >= 293")
# Create a reference cycle of two jax.Arrays.
_create_array_cycle()
# Use mock_stderr to be able to inspect stderr.
mock_stderr = io.StringIO()
with mock.patch("sys.stderr", mock_stderr):
# Trigger a garbage collection, which will garbage collect the arrays
# in the cycle.
gc.collect()
# Check that no error message is logged because
# `array_garbage_collection_guard` defaults to `allow`.
self.assertNotIn(
"`jax.Array` was deleted by the Python garbage collector",
mock_stderr.getvalue(),
)
def test_gced_array_is_logged(self):
if xla_extension_version < 293:
self.skipTest("Requires xla_extension_version >= 293")
# Use mock_stderr to be able to inspect stderr.
mock_stderr = io.StringIO()
with config.array_garbage_collection_guard("log"):
# Create a reference cycle of two jax.Arrays.
_create_array_cycle()
with mock.patch("sys.stderr", mock_stderr):
gc.collect()
# Verify that an error message is logged because two jax.Arrays were garbage
# collected.
self.assertIn(
"`jax.Array` was deleted by the Python garbage collector",
mock_stderr.getvalue(),
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())