Allow uncommitted single device PyArray in C++ pjit path.

PiperOrigin-RevId: 481711690
This commit is contained in:
Kuangyuan Chen 2022-10-17 12:34:55 -07:00 committed by jax authors
parent 504b3c1b25
commit fd2f590b3b

View File

@ -696,6 +696,17 @@ def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@jax_config.jax_array(True)
def simple_dispatch_array_pjit(state):
a = jax.device_put(1)
b = jax.device_put(2)
f = pjit_lib.pjit(operator.add)
f(a, b)
while state:
f(a, b)
if __name__ == "__main__":
google_benchmark.main()