[JAX] Add concurrent execution support in colocated Python

This change makes asynchronous execution run without holding a mutex. This
allows colocated Python executions from multiple Python threads to run
concurrently.

PiperOrigin-RevId: 704340663
This commit is contained in:
Hyeontaek Lim 2024-12-09 10:42:52 -08:00 committed by jax authors
parent d908e0add9
commit 296d1670bf
2 changed files with 84 additions and 1 deletions

View File

@ -343,7 +343,9 @@ def _get_specialized_func(
async_execution_func = _make_async_execution_fun(info, specialization)
# Fall-through.
return async_execution_func(*args, **kwargs)
# Asynchronous execution runs outside of the mutex to allow concurrent
# execution for inline executors.
return async_execution_func(*args, **kwargs)
return specialized_func

View File

@ -13,9 +13,12 @@
# limitations under the License.
import contextlib
import threading
import time
from typing import Sequence
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
@ -241,6 +244,84 @@ class ColocatedPythonTest(jtu.JaxTestCase):
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
@parameterized.named_parameters(
("on_main_thread", True),
("on_non_main_thread", False),
)
def testSequentialExecution(self, on_main_thread: bool):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
# Make sure that this input array is ready for use by the colocated Python
# function and does not disrupt elapsed time measurement.
jax.block_until_ready(x)
@colocated_python.colocated_python
def sleep(x: jax.Array) -> jax.Array:
time.sleep(5)
return x
# Specify out_specs_fn so that all executions are asynchronously dispatched.
sleep = sleep.specialize(out_specs_fn=lambda x: x)
def sleep_twice_and_wait(x: jax.Array) -> None:
_ = sleep(x)
jax.block_until_ready(sleep(x))
start_time = time.time()
# Two executions of `sleep` within `sleep_twice_and_wait` should run
# sequentially.
if on_main_thread:
sleep_twice_and_wait(x)
else:
t = threading.Thread(target=sleep_twice_and_wait, args=(x,))
t.start()
t.join()
elapsed_time = time.time() - start_time
# If sequential execution did not happen, elapsed time typically will be
# around 5 seconds.
self.assertGreaterEqual(elapsed_time, 10)
def testConcurrentExecution(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
# Make sure that this input array is ready for use by the colocated Python
# function and does not disrupt elapsed time measurement.
jax.block_until_ready(x)
@colocated_python.colocated_python
def sleep(x: jax.Array) -> jax.Array:
time.sleep(5)
return x
# Specify out_specs_fn so that all executions are asynchronously dispatched.
sleep = sleep.specialize(out_specs_fn=lambda x: x)
def sleep_and_wait(x: jax.Array) -> None:
jax.block_until_ready(sleep(x))
start_time = time.time()
# All three executions of `sleep_and_wait` should run concurrently.
t1 = threading.Thread(target=sleep_and_wait, args=(x,))
t2 = threading.Thread(target=sleep_and_wait, args=(x,))
t1.start()
t2.start()
sleep_and_wait(x)
t1.join()
t2.join()
elapsed_time = time.time() - start_time
self.assertGreaterEqual(elapsed_time, 5)
# If concurrent execution did not happen, elapsed time typically will be
# around 15 seconds.
self.assertLess(elapsed_time, 10)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())