mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test. PiperOrigin-RevId: 462239974
This commit is contained in:
parent
d8cbb29d14
commit
c0ec3b33e6
@ -63,6 +63,7 @@ from jax._src.api import (
|
||||
block_until_ready,
|
||||
checkpoint as checkpoint,
|
||||
checkpoint_policies as checkpoint_policies,
|
||||
clear_backends as clear_backends,
|
||||
closure_convert as closure_convert,
|
||||
curry, # TODO(phawkins): update users to avoid this.
|
||||
custom_gradient as custom_gradient,
|
||||
|
@ -3343,3 +3343,20 @@ def block_until_ready(x):
|
||||
except AttributeError:
|
||||
return x
|
||||
return jax.tree_util.tree_map(try_to_block, x)
|
||||
|
||||
|
||||
def clear_backends():
|
||||
"""
|
||||
Clear all backend clients so that new backend clients can be created later.
|
||||
"""
|
||||
|
||||
if xc._version < 79:
|
||||
raise RuntimeError("clear_backends is not supported in the jaxlib used."
|
||||
"Please update your jaxlib package.")
|
||||
|
||||
xb._clear_backends()
|
||||
jax.lib.xla_bridge._backends = {}
|
||||
dispatch.xla_callable.cache_clear() # type: ignore
|
||||
dispatch.xla_primitive_callable.cache_clear()
|
||||
_cpp_jit_cache.clear()
|
||||
jax_jit.CompiledFunctionCache.clear_all()
|
||||
|
@ -337,6 +337,20 @@ def backends():
|
||||
return _backends
|
||||
|
||||
|
||||
def _clear_backends():
|
||||
global _backends
|
||||
global _backends_errors
|
||||
global _default_backend
|
||||
|
||||
logging.info("Clearing JAX backend caches.")
|
||||
with _backend_lock:
|
||||
_backends = {}
|
||||
_backends_errors = {}
|
||||
_default_backend = None
|
||||
|
||||
get_backend.cache_clear()
|
||||
|
||||
|
||||
def _init_backend(platform):
|
||||
factory, unused_priority = _backend_factories.get(platform, (None, None))
|
||||
if factory is None:
|
||||
|
@ -887,6 +887,11 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "clear_backends_test",
|
||||
srcs = ["clear_backends_test.py"],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
40
tests/clear_backends_test.py
Normal file
40
tests/clear_backends_test.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for release_backend_clients."""
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_clear_backends(self):
|
||||
g = jax.jit(lambda x, y: x * y)
|
||||
self.assertEqual(g(1, 2), 2)
|
||||
if xc._version >= 79:
|
||||
self.assertNotEmpty(xb.get_backend().live_executables())
|
||||
jax.clear_backends()
|
||||
self.assertEmpty(xb.get_backend().live_executables())
|
||||
self.assertEqual(g(1, 2), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user