diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index a1486a7d8..8e1d676a7 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import threading import time from typing import Sequence @@ -56,7 +57,8 @@ def _colocated_cpu_devices( _count_colocated_python_specialization_cache_miss = jtu.count_events( - "colocated_python_func._get_specialized_func") + "colocated_python_func._get_specialized_func" +) class ColocatedPythonTest(jtu.JaxTestCase): @@ -335,6 +337,47 @@ class ColocatedPythonTest(jtu.JaxTestCase): out = jax.device_get(out) np.testing.assert_equal(out, np.array([2 + 4, 0 + 8])) + def testModuleVariableAccess(self): + try: + # The following pattern of storing and accessing non-serialized state in + # the Python module is discouraged for storing user-defined state. + # However, it should still work because many caching mechanisms rely on + # this behavior. + + # Poison the test's own `colocated_python` module with a non-serializable + # object (file) to detect any invalid attempt to serialize the module as + # part of a colocated Python function. + colocated_python._testing_non_serializable_object = ( + tempfile.TemporaryFile() + ) + + @colocated_python.colocated_python + def set_global_state(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = x + return x + 1 + + @colocated_python.colocated_python + def get_global_state(x: jax.Array) -> jax.Array: + del x + return colocated_python._testing_global_state + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + y = np.array(2) + y = jax.device_put(y, cpu_devices[0]) + + jax.block_until_ready(set_global_state(x)) + out = jax.device_get(get_global_state(y)) + + np.testing.assert_equal(out, np.array(1)) + finally: + if "_testing_non_serializable_object" in colocated_python.__dict__: + colocated_python._testing_non_serializable_object.close() + del colocated_python._testing_non_serializable_object + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())