mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add a test verifying the behavior of module-level state accessed by colocated Python
A new test verifies that * Python module-level variables can be created/set and read from a colocated Python function * Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling An API for defining user-defined state and accessing it from multiple colocated Python functions (i.e., object support) will be added later. That will be a recommended way to express user-defined state. The capability of accessing Python module variables is still crucial because a lot of Python code (including JAX) requires this behavior to implement caching. PiperOrigin-RevId: 723595727
This commit is contained in:
parent
10363663e8
commit
f43d2b68d9
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from typing import Sequence
|
||||
@ -56,7 +57,8 @@ def _colocated_cpu_devices(
|
||||
|
||||
|
||||
_count_colocated_python_specialization_cache_miss = jtu.count_events(
|
||||
"colocated_python_func._get_specialized_func")
|
||||
"colocated_python_func._get_specialized_func"
|
||||
)
|
||||
|
||||
|
||||
class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
@ -335,6 +337,47 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = jax.device_get(out)
|
||||
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
|
||||
|
||||
def testModuleVariableAccess(self):
|
||||
try:
|
||||
# The following pattern of storing and accessing non-serialized state in
|
||||
# the Python module is discouraged for storing user-defined state.
|
||||
# However, it should still work because many caching mechanisms rely on
|
||||
# this behavior.
|
||||
|
||||
# Poison the test's own `colocated_python` module with a non-serializable
|
||||
# object (file) to detect any invalid attempt to serialize the module as
|
||||
# part of a colocated Python function.
|
||||
colocated_python._testing_non_serializable_object = (
|
||||
tempfile.TemporaryFile()
|
||||
)
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def set_global_state(x: jax.Array) -> jax.Array:
|
||||
colocated_python._testing_global_state = x
|
||||
return x + 1
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def get_global_state(x: jax.Array) -> jax.Array:
|
||||
del x
|
||||
return colocated_python._testing_global_state
|
||||
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())
|
||||
x = np.array(1)
|
||||
x = jax.device_put(x, cpu_devices[0])
|
||||
y = np.array(2)
|
||||
y = jax.device_put(y, cpu_devices[0])
|
||||
|
||||
jax.block_until_ready(set_global_state(x))
|
||||
out = jax.device_get(get_global_state(y))
|
||||
|
||||
np.testing.assert_equal(out, np.array(1))
|
||||
finally:
|
||||
if "_testing_non_serializable_object" in colocated_python.__dict__:
|
||||
colocated_python._testing_non_serializable_object.close()
|
||||
del colocated_python._testing_non_serializable_object
|
||||
if "_testing_global_state" in colocated_python.__dict__:
|
||||
del colocated_python._testing_global_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user