mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00

Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile. PiperOrigin-RevId: 714037277
70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for garbage allocation guard."""
|
|
|
|
import gc
|
|
import weakref
|
|
|
|
from absl.testing import absltest
|
|
import jax
|
|
from jax._src import config
|
|
import jax._src.test_util as jtu
|
|
import jax.numpy as jnp
|
|
|
|
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
def _create_array_cycle():
|
|
"""Creates a reference cycle of two jax.Arrays."""
|
|
n1 = jnp.ones((2, 2))
|
|
n2 = jnp.zeros((2, 2))
|
|
n1.next = n2
|
|
n2.next = n1
|
|
return weakref.ref(n1)
|
|
|
|
|
|
@jtu.thread_unsafe_test_class() # GC isn't predictable when threaded.
|
|
class GarbageCollectionGuardTest(jtu.JaxTestCase):
|
|
|
|
def test_gced_array_is_not_logged_by_default(self):
|
|
# Create a reference cycle of two jax.Arrays.
|
|
ref = _create_array_cycle()
|
|
with jtu.capture_stderr() as stderr:
|
|
self.assertIsNotNone(ref()) # Cycle still alive.
|
|
gc.collect()
|
|
self.assertIsNone(ref()) # Cycle collected.
|
|
# Check that no error message is logged because
|
|
# `array_garbage_collection_guard` defaults to `allow`.
|
|
self.assertNotIn(
|
|
"`jax.Array` was deleted by the Python garbage collector", stderr(),
|
|
)
|
|
|
|
def test_gced_array_is_logged(self):
|
|
with config.array_garbage_collection_guard("log"):
|
|
with jtu.capture_stderr() as stderr:
|
|
# Create a reference cycle of two jax.Arrays.
|
|
ref = _create_array_cycle()
|
|
self.assertIsNotNone(ref()) # Cycle still alive.
|
|
gc.collect()
|
|
self.assertIsNone(ref()) # Cycle collected.
|
|
# Verify that an error message is logged because two jax.Arrays were garbage
|
|
# collected.
|
|
self.assertIn(
|
|
"`jax.Array` was deleted by the Python garbage collector", stderr()
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|