From c0ec3b33e687ce37b431906109d4a2bc4655285f Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Wed, 20 Jul 2022 15:09:47 -0700 Subject: [PATCH] Introduce jax.experimental.clear_backends to delete all JAX runtime backends. In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test. PiperOrigin-RevId: 462239974 --- jax/__init__.py | 1 + jax/_src/api.py | 17 +++++++++++++++ jax/_src/lib/xla_bridge.py | 14 +++++++++++++ tests/BUILD | 5 +++++ tests/clear_backends_test.py | 40 ++++++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+) create mode 100644 tests/clear_backends_test.py diff --git a/jax/__init__.py b/jax/__init__.py index 6b5f8a82d..57c73bb08 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -63,6 +63,7 @@ from jax._src.api import ( block_until_ready, checkpoint as checkpoint, checkpoint_policies as checkpoint_policies, + clear_backends as clear_backends, closure_convert as closure_convert, curry, # TODO(phawkins): update users to avoid this. custom_gradient as custom_gradient, diff --git a/jax/_src/api.py b/jax/_src/api.py index e12d1f517..0723c7e4a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3343,3 +3343,20 @@ def block_until_ready(x): except AttributeError: return x return jax.tree_util.tree_map(try_to_block, x) + + +def clear_backends(): + """ + Clear all backend clients so that new backend clients can be created later. + """ + + if xc._version < 79: + raise RuntimeError("clear_backends is not supported in the jaxlib used." + "Please update your jaxlib package.") + + xb._clear_backends() + jax.lib.xla_bridge._backends = {} + dispatch.xla_callable.cache_clear() # type: ignore + dispatch.xla_primitive_callable.cache_clear() + _cpp_jit_cache.clear() + jax_jit.CompiledFunctionCache.clear_all() diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index a971b5a87..256684be6 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -337,6 +337,20 @@ def backends(): return _backends +def _clear_backends(): + global _backends + global _backends_errors + global _default_backend + + logging.info("Clearing JAX backend caches.") + with _backend_lock: + _backends = {} + _backends_errors = {} + _default_backend = None + + get_backend.cache_clear() + + def _init_backend(platform): factory, unused_priority = _backend_factories.get(platform, (None, None)) if factory is None: diff --git a/tests/BUILD b/tests/BUILD index 7a65df699..1f7cedc21 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -887,6 +887,11 @@ jax_test( ], ) +jax_test( + name = "clear_backends_test", + srcs = ["clear_backends_test.py"], +) + exports_files( [ "api_test.py", diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py new file mode 100644 index 000000000..66639f10b --- /dev/null +++ b/tests/clear_backends_test.py @@ -0,0 +1,40 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for release_backend_clients.""" + +from absl.testing import absltest + +import jax +from jax.config import config +from jax._src import test_util as jtu +from jax._src.lib import xla_bridge as xb +from jax._src.lib import xla_client as xc + +config.parse_flags_with_absl() + + +class ClearBackendsTest(jtu.JaxTestCase): + + def test_clear_backends(self): + g = jax.jit(lambda x, y: x * y) + self.assertEqual(g(1, 2), 2) + if xc._version >= 79: + self.assertNotEmpty(xb.get_backend().live_executables()) + jax.clear_backends() + self.assertEmpty(xb.get_backend().live_executables()) + self.assertEqual(g(1, 2), 2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())