diff --git a/jax/_src/config.py b/jax/_src/config.py index 5a0c80a4f..8bebd7d90 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -25,8 +25,8 @@ import threading from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast from jax._src import lib +from jax._src.lib import guard_lib from jax._src.lib import jax_jit -from jax._src.lib import transfer_guard_lib from jax._src.lib import xla_client from jax._src import logging_config @@ -1596,7 +1596,7 @@ jax_xla_profile_version = int_state( @contextlib.contextmanager def explicit_device_put_scope() -> Iterator[None]: """Indicates that the current context is an explicit device_put*() call.""" - state = transfer_guard_lib.thread_local_state() + state = guard_lib.thread_local_state() prev = state.explicit_device_put state.explicit_device_put = True try: @@ -1607,7 +1607,7 @@ def explicit_device_put_scope() -> Iterator[None]: @contextlib.contextmanager def explicit_device_get_scope() -> Iterator[None]: """Indicates that the current context is an explicit device_get() call.""" - state = transfer_guard_lib.thread_local_state() + state = guard_lib.thread_local_state() prev = state.explicit_device_get state.explicit_device_get = True try: @@ -1616,19 +1616,19 @@ def explicit_device_get_scope() -> Iterator[None]: state.explicit_device_get = prev def _update_transfer_guard(state, key, val): - """Applies the transfer guard level within transfer_guard_lib.""" + """Applies the transfer guard level within guard_lib.""" if val is None: setattr(state, key, None) elif val == 'allow': - setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW) + setattr(state, key, guard_lib.TransferGuardLevel.ALLOW) elif val == 'log': - setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG) + setattr(state, key, guard_lib.TransferGuardLevel.LOG) elif val == 'disallow': - setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW) + setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW) elif val == 'log_explicit': - setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT) + setattr(state, key, guard_lib.TransferGuardLevel.LOG_EXPLICIT) elif val == 'disallow_explicit': - setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT) + setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT) else: assert False, f'Invalid transfer guard level {val}' @@ -1637,45 +1637,46 @@ transfer_guard_host_to_device = optional_enum_state( enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' ], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. + # The default is applied by guard_lib. Use None here to avoid accidentally + # overriding --jax_transfer_guard. default=None, help=('Select the transfer guard level for host-to-device transfers. ' 'Default is "allow".'), update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), 'host_to_device', val), + guard_lib.global_state(), 'host_to_device', val), update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), 'host_to_device', val)) + guard_lib.thread_local_state(), 'host_to_device', val)) transfer_guard_device_to_device = optional_enum_state( name='jax_transfer_guard_device_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' ], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. + # The default is applied by guard_lib. Use None here to avoid accidentally + # overriding --jax_transfer_guard. default=None, help=('Select the transfer guard level for device-to-device transfers. ' 'Default is "allow".'), update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), 'device_to_device', val), + guard_lib.global_state(), 'device_to_device', val), update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), 'device_to_device', val)) + guard_lib.thread_local_state(), 'device_to_device', val)) transfer_guard_device_to_host = optional_enum_state( name='jax_transfer_guard_device_to_host', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' ], - # The default is applied by transfer_guard_lib. Use None here to avoid + # The default is applied by guard_lib. Use None here to avoid # accidentally overriding --jax_transfer_guard. default=None, help=('Select the transfer guard level for device-to-host transfers. ' 'Default is "allow".'), update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), 'device_to_host', val), + guard_lib.global_state(), 'device_to_host', val + ), update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), 'device_to_host', val)) + guard_lib.thread_local_state(), 'device_to_host', val)) def _update_all_transfer_guard_global(val): for name in ('jax_transfer_guard_host_to_device', @@ -1688,8 +1689,8 @@ _transfer_guard = optional_enum_state( enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' ], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard_*. + # The default is applied by guard_lib. Use None here to avoid accidentally + # overriding --jax_transfer_guard_*. default=None, help=('Select the transfer guard level for all transfers. This option is ' 'set-only; the transfer guard level for a specific direction should ' @@ -1718,6 +1719,52 @@ def transfer_guard(new_val: str) -> Iterator[None]: yield +if lib.xla_extension_version < 293: + + def array_garbage_collection_guard(_val): + raise NotImplementedError( + 'jaxlib version is too low for garbage collection guard' + ) + +else: + def _update_garbage_collection_guard(state, key, val): + """Applies the transfer guard level within guard_lib.""" + if val is None: + setattr(state, key, None) + elif val == 'allow': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW) + elif val == 'log': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG) + elif val == 'fatal': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL) + else: + assert False, f'Invalid garbage collection guard level {val}' + + array_garbage_collection_guard = optional_enum_state( + name='jax_array_garbage_collection_guard', + enum_values=['allow', 'log', 'fatal'], + # The default is applied by guard_lib. + default=None, + help=( + 'Select garbage collection guard level for "jax.Array" objects.\nThis' + ' option can be used to control what happens when a "jax.Array"' + ' object is garbage collected. It is desirable for "jax.Array"' + ' objects to be freed by Python reference couting rather than garbage' + ' collection in order to avoid device memory being held by the arrays' + ' until garbage collection occurs.\n\nValid values are:\n * "allow":' + ' do not log garbage collection of "jax.Array" objects.\n * "log":' + ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' + ' fatal error if a "jax.Array" is garbage collected.\nDefault is' + ' "allow".' + ), + update_global_hook=lambda val: _update_garbage_collection_guard( + guard_lib.global_state(), 'garbage_collect_array', val + ), + update_thread_local_hook=lambda val: _update_garbage_collection_guard( + guard_lib.thread_local_state(), 'garbage_collect_array', val + ), + ) + def _update_debug_log_modules(module_names_str: str | None): logging_config.disable_all_debug_logging() if not module_names_str: diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 68a0f1553..ea9191b2c 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -155,6 +155,9 @@ def _cuda_path() -> str | None: cuda_path = _cuda_path() -transfer_guard_lib = xla_client._xla.transfer_guard_lib +if version >= (0, 4, 35): + guard_lib = xla_client._xla.guard_lib +else: + guard_lib = xla_client._xla.transfer_guard_lib Device = xla_client._xla.Device diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index cbbce31f1..5ebdf6e4c 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -100,6 +100,7 @@ def patch_copy_mlir_import(src_file, dst_dir): _XLA_EXTENSION_STUBS = [ "__init__.pyi", + "guard_lib.pyi", "ifrt_programs.pyi", "ifrt_proxy.pyi", "jax_jit.pyi", diff --git a/tests/BUILD b/tests/BUILD index 615437ce4..dc5d5b373 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1203,6 +1203,11 @@ jax_multiplatform_test( srcs = ["transfer_guard_test.py"], ) +jax_multiplatform_test( + name = "garbage_collection_guard_test", + srcs = ["garbage_collection_guard_test.py"], +) + jax_multiplatform_test( name = "name_stack_test", srcs = ["name_stack_test.py"], diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py new file mode 100644 index 000000000..d23d239dd --- /dev/null +++ b/tests/garbage_collection_guard_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for garbage allocation guard.""" + +import gc +import io +from unittest import mock + +from absl.testing import absltest +import jax +from jax._src import config +from jax._src.lib import xla_extension_version +import jax._src.test_util as jtu +import jax.numpy as jnp + +jax.config.parse_flags_with_absl() + + +# Helper class used to create a reference cycle. +class GarbageCollectionGuardTestNodeHelper: + + def __init__(self, data): + self.data = data + self.next = None + + +def _create_array_cycle(): + """Creates a reference cycle of two jax.Arrays.""" + n1 = GarbageCollectionGuardTestNodeHelper(jnp.ones((2, 2))) + n2 = GarbageCollectionGuardTestNodeHelper(jnp.zeros((2, 2))) + n1.next = n2 + n2.next = n1 + + +class GarbageCollectionGuardTest(jtu.JaxTestCase): + + def test_gced_array_is_not_logged_by_default(self): + if xla_extension_version < 293: + self.skipTest("Requires xla_extension_version >= 293") + + # Create a reference cycle of two jax.Arrays. + _create_array_cycle() + + # Use mock_stderr to be able to inspect stderr. + mock_stderr = io.StringIO() + with mock.patch("sys.stderr", mock_stderr): + # Trigger a garbage collection, which will garbage collect the arrays + # in the cycle. + gc.collect() + # Check that no error message is logged because + # `array_garbage_collection_guard` defaults to `allow`. + self.assertNotIn( + "`jax.Array` was deleted by the Python garbage collector", + mock_stderr.getvalue(), + ) + + def test_gced_array_is_logged(self): + if xla_extension_version < 293: + self.skipTest("Requires xla_extension_version >= 293") + + # Use mock_stderr to be able to inspect stderr. + mock_stderr = io.StringIO() + + with config.array_garbage_collection_guard("log"): + # Create a reference cycle of two jax.Arrays. + _create_array_cycle() + with mock.patch("sys.stderr", mock_stderr): + gc.collect() + + # Verify that an error message is logged because two jax.Arrays were garbage + # collected. + self.assertIn( + "`jax.Array` was deleted by the Python garbage collector", + mock_stderr.getvalue(), + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())