mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add concurrent execution support in colocated Python
This change makes asynchronous execution run without holding the mutex, which allows colocated Python executions issued from multiple Python threads to run concurrently.

PiperOrigin-RevId: 704340663
This commit is contained in:
parent
d908e0add9
commit
296d1670bf
@ -343,7 +343,9 @@ def _get_specialized_func(
|
||||
async_execution_func = _make_async_execution_fun(info, specialization)
|
||||
# Fall-through.
|
||||
|
||||
return async_execution_func(*args, **kwargs)
|
||||
# Asynchronous execution runs outside of the mutex to allow concurrent
|
||||
# execution for inline executors.
|
||||
return async_execution_func(*args, **kwargs)
|
||||
|
||||
return specialized_func
|
||||
|
||||
|
@ -13,9 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
@ -241,6 +244,84 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 2)
|
||||
|
||||
@parameterized.named_parameters(
    ("on_main_thread", True),
    ("on_non_main_thread", False),
)
def testSequentialExecution(self, on_main_thread: bool):
  """Checks that two dispatches issued from one thread run back to back.

  Args:
    on_main_thread: If True, the two executions are issued from the main
      thread; otherwise they are issued from a separate `threading.Thread`.
  """
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  x = np.array(1)
  x = jax.device_put(x, cpu_devices[0])
  # Make sure that this input array is ready for use by the colocated Python
  # function and does not disrupt elapsed time measurement.
  jax.block_until_ready(x)

  @colocated_python.colocated_python
  def sleep(x: jax.Array) -> jax.Array:
    time.sleep(5)
    return x

  # Specify out_specs_fn so that all executions are asynchronously dispatched.
  sleep = sleep.specialize(out_specs_fn=lambda x: x)

  def sleep_twice_and_wait(x: jax.Array) -> None:
    _ = sleep(x)
    jax.block_until_ready(sleep(x))

  # Exceptions raised inside a `threading.Thread` target do not propagate to
  # the thread that joins it, so they are collected here and checked below.
  thread_errors = []

  def sleep_twice_and_wait_in_thread(x: jax.Array) -> None:
    try:
      sleep_twice_and_wait(x)
    except Exception as e:  # Re-reported on the main thread below.
      thread_errors.append(e)

  # Use a monotonic clock: wall-clock time may jump while the test sleeps.
  start_time = time.monotonic()

  # Two executions of `sleep` within `sleep_twice_and_wait` should run
  # sequentially.
  if on_main_thread:
    sleep_twice_and_wait(x)
  else:
    t = threading.Thread(target=sleep_twice_and_wait_in_thread, args=(x,))
    t.start()
    t.join()

  elapsed_time = time.monotonic() - start_time

  # Surface a worker failure directly instead of as a confusing timing error.
  self.assertEqual(thread_errors, [])

  # If sequential execution did not happen, elapsed time typically will be
  # around 5 seconds.
  self.assertGreaterEqual(elapsed_time, 10)
|
||||
|
||||
def testConcurrentExecution(self):
  """Checks that executions issued from different threads overlap in time.

  Three 5-second executions are issued: two from worker threads and one from
  the calling thread. The total elapsed time is expected to be at least 5
  seconds and less than 10 seconds.
  """
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  x = np.array(1)
  x = jax.device_put(x, cpu_devices[0])
  # Make sure that this input array is ready for use by the colocated Python
  # function and does not disrupt elapsed time measurement.
  jax.block_until_ready(x)

  @colocated_python.colocated_python
  def sleep(x: jax.Array) -> jax.Array:
    time.sleep(5)
    return x

  # Specify out_specs_fn so that all executions are asynchronously dispatched.
  sleep = sleep.specialize(out_specs_fn=lambda x: x)

  def sleep_and_wait(x: jax.Array) -> None:
    jax.block_until_ready(sleep(x))

  # Exceptions raised inside a `threading.Thread` target do not propagate to
  # the thread that joins it. Without collecting them, this test could pass
  # on timing alone even if both worker executions failed immediately.
  thread_errors = []

  def sleep_and_wait_in_thread(x: jax.Array) -> None:
    try:
      sleep_and_wait(x)
    except Exception as e:  # Re-reported on the main thread below.
      thread_errors.append(e)

  # Use a monotonic clock: wall-clock time may jump while the test sleeps.
  start_time = time.monotonic()

  # All three executions of `sleep_and_wait` should run concurrently.
  t1 = threading.Thread(target=sleep_and_wait_in_thread, args=(x,))
  t2 = threading.Thread(target=sleep_and_wait_in_thread, args=(x,))
  t1.start()
  t2.start()
  try:
    sleep_and_wait(x)
  finally:
    # Always join so that worker threads do not outlive a failing test.
    t1.join()
    t2.join()

  elapsed_time = time.monotonic() - start_time

  self.assertEqual(thread_errors, [])
  self.assertGreaterEqual(elapsed_time, 5)
  # If concurrent execution did not happen, elapsed time typically will be
  # around 15 seconds.
  self.assertLess(elapsed_time, 10)
|
||||
|
||||
|
||||
# Script entry point: hand control to absltest, using the loader from jtu.
if __name__ == "__main__":
  absltest.main(
      testLoader=jtu.JaxTestLoader(),
  )
|
||||
|
Loading…
x
Reference in New Issue
Block a user