Introduce jax.experimental.clear_backends to delete all JAX runtime backends.

In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
This commit is contained in:
Kuangyuan Chen 2022-07-20 15:09:47 -07:00 committed by jax authors
parent d8cbb29d14
commit c0ec3b33e6
5 changed files with 77 additions and 0 deletions

View File

@ -63,6 +63,7 @@ from jax._src.api import (
block_until_ready,
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
clear_backends as clear_backends,
closure_convert as closure_convert,
curry, # TODO(phawkins): update users to avoid this.
custom_gradient as custom_gradient,

View File

@ -3343,3 +3343,20 @@ def block_until_ready(x):
except AttributeError:
return x
return jax.tree_util.tree_map(try_to_block, x)
def clear_backends():
"""
Clear all backend clients so that new backend clients can be created later.
"""
if xc._version < 79:
raise RuntimeError("clear_backends is not supported in the jaxlib used."
"Please update your jaxlib package.")
xb._clear_backends()
jax.lib.xla_bridge._backends = {}
dispatch.xla_callable.cache_clear() # type: ignore
dispatch.xla_primitive_callable.cache_clear()
_cpp_jit_cache.clear()
jax_jit.CompiledFunctionCache.clear_all()

View File

@ -337,6 +337,20 @@ def backends():
return _backends
def _clear_backends():
global _backends
global _backends_errors
global _default_backend
logging.info("Clearing JAX backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
_default_backend = None
get_backend.cache_clear()
def _init_backend(platform):
factory, unused_priority = _backend_factories.get(platform, (None, None))
if factory is None:

View File

@ -887,6 +887,11 @@ jax_test(
],
)
jax_test(
name = "clear_backends_test",
srcs = ["clear_backends_test.py"],
)
exports_files(
[
"api_test.py",

View File

@ -0,0 +1,40 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for release_backend_clients."""
from absl.testing import absltest
import jax
from jax.config import config
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
config.parse_flags_with_absl()
class ClearBackendsTest(jtu.JaxTestCase):
def test_clear_backends(self):
g = jax.jit(lambda x, y: x * y)
self.assertEqual(g(1, 2), 2)
if xc._version >= 79:
self.assertNotEmpty(xb.get_backend().live_executables())
jax.clear_backends()
self.assertEmpty(xb.get_backend().live_executables())
self.assertEqual(g(1, 2), 2)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())