[JAX] Add a test verifying the behavior of module-level state accessed by colocated Python

A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be a
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 723595727
This commit is contained in:
Hyeontaek Lim 2025-02-05 11:48:25 -08:00 committed by jax authors
parent 10363663e8
commit f43d2b68d9

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import threading
import time
from typing import Sequence
@ -56,7 +57,8 @@ def _colocated_cpu_devices(
_count_colocated_python_specialization_cache_miss = jtu.count_events(
"colocated_python_func._get_specialized_func")
"colocated_python_func._get_specialized_func"
)
class ColocatedPythonTest(jtu.JaxTestCase):
@ -335,6 +337,47 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = jax.device_get(out)
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
def testModuleVariableAccess(self):
try:
# The following pattern of storing and accessing non-serialized state in
# the Python module is discouraged for storing user-defined state.
# However, it should still work because many caching mechanisms rely on
# this behavior.
# Poison the test's own `colocated_python` module with a non-serializable
# object (file) to detect any invalid attempt to serialize the module as
# part of a colocated Python function.
colocated_python._testing_non_serializable_object = (
tempfile.TemporaryFile()
)
@colocated_python.colocated_python
def set_global_state(x: jax.Array) -> jax.Array:
colocated_python._testing_global_state = x
return x + 1
@colocated_python.colocated_python
def get_global_state(x: jax.Array) -> jax.Array:
del x
return colocated_python._testing_global_state
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
y = np.array(2)
y = jax.device_put(y, cpu_devices[0])
jax.block_until_ready(set_global_state(x))
out = jax.device_get(get_global_state(y))
np.testing.assert_equal(out, np.array(1))
finally:
if "_testing_non_serializable_object" in colocated_python.__dict__:
colocated_python._testing_non_serializable_object.close()
del colocated_python._testing_non_serializable_object
if "_testing_global_state" in colocated_python.__dict__:
del colocated_python._testing_global_state
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())