mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, a tri-state config option that sets up a garbage collection guard for `jax.Array` objects. The possible values are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is garbage collected, a log entry is generated with the array's traceback.
* fatal: the process crashes with a fatal error when a `jax.Array` is garbage collected. This is meant for mature code bases that do tight memory management and are free of reference cycles.
PiperOrigin-RevId: 687003464
This commit is contained in:
parent
1b5cf5a494
commit
ec279f9c54
@ -25,8 +25,8 @@ import threading
|
||||
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import guard_lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import logging_config
|
||||
|
||||
@ -1596,7 +1596,7 @@ jax_xla_profile_version = int_state(
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
state = guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_put
|
||||
state.explicit_device_put = True
|
||||
try:
|
||||
@ -1607,7 +1607,7 @@ def explicit_device_put_scope() -> Iterator[None]:
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
state = guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_get
|
||||
state.explicit_device_get = True
|
||||
try:
|
||||
@ -1616,19 +1616,19 @@ def explicit_device_get_scope() -> Iterator[None]:
|
||||
state.explicit_device_get = prev
|
||||
|
||||
def _update_transfer_guard(state, key, val):
|
||||
"""Applies the transfer guard level within transfer_guard_lib."""
|
||||
"""Applies the transfer guard level within guard_lib."""
|
||||
if val is None:
|
||||
setattr(state, key, None)
|
||||
elif val == 'allow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
|
||||
setattr(state, key, guard_lib.TransferGuardLevel.ALLOW)
|
||||
elif val == 'log':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
|
||||
setattr(state, key, guard_lib.TransferGuardLevel.LOG)
|
||||
elif val == 'disallow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
|
||||
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW)
|
||||
elif val == 'log_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
||||
setattr(state, key, guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
||||
elif val == 'disallow_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
||||
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
||||
else:
|
||||
assert False, f'Invalid transfer guard level {val}'
|
||||
|
||||
@ -1637,45 +1637,46 @@ transfer_guard_host_to_device = optional_enum_state(
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
# The default is applied by guard_lib. Use None here to avoid accidentally
|
||||
# overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for host-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'host_to_device', val),
|
||||
guard_lib.global_state(), 'host_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
|
||||
guard_lib.thread_local_state(), 'host_to_device', val))
|
||||
|
||||
transfer_guard_device_to_device = optional_enum_state(
|
||||
name='jax_transfer_guard_device_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
# The default is applied by guard_lib. Use None here to avoid accidentally
|
||||
# overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_device', val),
|
||||
guard_lib.global_state(), 'device_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
|
||||
guard_lib.thread_local_state(), 'device_to_device', val))
|
||||
|
||||
transfer_guard_device_to_host = optional_enum_state(
|
||||
name='jax_transfer_guard_device_to_host',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# The default is applied by guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-host transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_host', val),
|
||||
guard_lib.global_state(), 'device_to_host', val
|
||||
),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
|
||||
guard_lib.thread_local_state(), 'device_to_host', val))
|
||||
|
||||
def _update_all_transfer_guard_global(val):
|
||||
for name in ('jax_transfer_guard_host_to_device',
|
||||
@ -1688,8 +1689,8 @@ _transfer_guard = optional_enum_state(
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard_*.
|
||||
# The default is applied by guard_lib. Use None here to avoid accidentally
|
||||
# overriding --jax_transfer_guard_*.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for all transfers. This option is '
|
||||
'set-only; the transfer guard level for a specific direction should '
|
||||
@ -1718,6 +1719,52 @@ def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
yield
|
||||
|
||||
|
||||
if lib.xla_extension_version < 293:

  def array_garbage_collection_guard(_val):
    """Placeholder used when `lib.xla_extension_version` is below 293.

    Always raises NotImplementedError; the argument is ignored.
    """
    raise NotImplementedError(
        'jaxlib version is too low for garbage collection guard'
    )

else:

  def _update_garbage_collection_guard(state, key, val):
    """Applies the garbage collection guard level within guard_lib.

    Args:
      state: the guard_lib state object whose attribute is updated (the hooks
        below pass either the global or the thread-local state).
      key: name of the attribute to set on `state`.
      val: one of 'allow', 'log', 'fatal', or None. None is stored as-is;
        the strings are mapped to `guard_lib.GarbageCollectionGuardLevel`.
    """
    if val is None:
      setattr(state, key, None)
    elif val == 'allow':
      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
    elif val == 'log':
      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
    elif val == 'fatal':
      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
    else:
      assert False, f'Invalid garbage collection guard level {val}'

  array_garbage_collection_guard = optional_enum_state(
      name='jax_array_garbage_collection_guard',
      enum_values=['allow', 'log', 'fatal'],
      # The default is applied by guard_lib.
      default=None,
      help=(
          'Select garbage collection guard level for "jax.Array" objects.\nThis'
          ' option can be used to control what happens when a "jax.Array"'
          ' object is garbage collected. It is desirable for "jax.Array"'
          ' objects to be freed by Python reference counting rather than garbage'
          ' collection in order to avoid device memory being held by the arrays'
          ' until garbage collection occurs.\n\nValid values are:\n * "allow":'
          ' do not log garbage collection of "jax.Array" objects.\n * "log":'
          ' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
          ' fatal error if a "jax.Array" is garbage collected.\nDefault is'
          ' "allow".'
      ),
      update_global_hook=lambda val: _update_garbage_collection_guard(
          guard_lib.global_state(), 'garbage_collect_array', val
      ),
      update_thread_local_hook=lambda val: _update_garbage_collection_guard(
          guard_lib.thread_local_state(), 'garbage_collect_array', val
      ),
  )
|
||||
|
||||
def _update_debug_log_modules(module_names_str: str | None):
|
||||
logging_config.disable_all_debug_logging()
|
||||
if not module_names_str:
|
||||
|
@ -155,6 +155,9 @@ def _cuda_path() -> str | None:
|
||||
|
||||
cuda_path = _cuda_path()
|
||||
|
||||
transfer_guard_lib = xla_client._xla.transfer_guard_lib
|
||||
if version >= (0, 4, 35):
|
||||
guard_lib = xla_client._xla.guard_lib
|
||||
else:
|
||||
guard_lib = xla_client._xla.transfer_guard_lib
|
||||
|
||||
Device = xla_client._xla.Device
|
||||
|
@ -100,6 +100,7 @@ def patch_copy_mlir_import(src_file, dst_dir):
|
||||
|
||||
_XLA_EXTENSION_STUBS = [
|
||||
"__init__.pyi",
|
||||
"guard_lib.pyi",
|
||||
"ifrt_programs.pyi",
|
||||
"ifrt_proxy.pyi",
|
||||
"jax_jit.pyi",
|
||||
|
@ -1203,6 +1203,11 @@ jax_multiplatform_test(
|
||||
srcs = ["transfer_guard_test.py"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "garbage_collection_guard_test",
|
||||
srcs = ["garbage_collection_guard_test.py"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "name_stack_test",
|
||||
srcs = ["name_stack_test.py"],
|
||||
|
90
tests/garbage_collection_guard_test.py
Normal file
90
tests/garbage_collection_guard_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for garbage collection guard."""
|
||||
|
||||
import gc
|
||||
import io
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class GarbageCollectionGuardTestNodeHelper:
  """Linked node that holds a payload; the tests use it to build a cycle."""

  def __init__(self, data):
    # `next` starts as None; callers link nodes together after construction.
    self.data, self.next = data, None
|
||||
|
||||
|
||||
def _create_array_cycle():
  """Builds two helper nodes, each holding a jax.Array, that reference each other.

  Nothing is returned, so after the call the nodes and their arrays are only
  referenced from within the cycle.
  """
  head = GarbageCollectionGuardTestNodeHelper(jnp.ones((2, 2)))
  tail = GarbageCollectionGuardTestNodeHelper(jnp.zeros((2, 2)))
  # Close the loop: head -> tail -> head.
  head.next, tail.next = tail, head
|
||||
|
||||
|
||||
class GarbageCollectionGuardTest(jtu.JaxTestCase):
  """Checks what is written to stderr when cyclic jax.Arrays are collected."""

  # Message the tests look for in the captured stderr output.
  _GC_LOG_MESSAGE = "`jax.Array` was deleted by the Python garbage collector"

  def _skip_if_guard_unavailable(self):
    """Skips the running test when xla_extension_version is below 293."""
    if xla_extension_version < 293:
      self.skipTest("Requires xla_extension_version >= 293")

  def test_gced_array_is_not_logged_by_default(self):
    self._skip_if_guard_unavailable()

    # Leave behind two arrays that are only reachable through a cycle.
    _create_array_cycle()

    # Swap sys.stderr for an in-memory buffer so its contents can be checked.
    captured = io.StringIO()
    with mock.patch("sys.stderr", captured):
      # Force a collection pass so the cyclic arrays are reclaimed now.
      gc.collect()

    # `array_garbage_collection_guard` defaults to `allow`, so nothing about
    # the collection should have been written.
    self.assertNotIn(self._GC_LOG_MESSAGE, captured.getvalue())

  def test_gced_array_is_logged(self):
    self._skip_if_guard_unavailable()

    # Swap sys.stderr for an in-memory buffer so its contents can be checked.
    captured = io.StringIO()

    with config.array_garbage_collection_guard("log"):
      # Leave behind two arrays that are only reachable through a cycle.
      _create_array_cycle()
      with mock.patch("sys.stderr", captured):
        gc.collect()

    # With the guard set to "log", collecting the cyclic arrays must have
    # produced the error message.
    self.assertIn(self._GC_LOG_MESSAGE, captured.getvalue())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user