[Rollback]

Allow uncommitted single device PyArray in C++ pjit path.

PiperOrigin-RevId: 482084898
This commit is contained in:
Yash Katariya 2022-10-18 19:41:42 -07:00 committed by jax authors
parent 807269990e
commit 3572bb2db0

View File

@ -699,17 +699,6 @@ def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@jax_config.jax_array(True)
def simple_dispatch_array_pjit(state):
a = jax.device_put(1)
b = jax.device_put(2)
f = pjit_lib.pjit(operator.add)
f(a, b)
while state:
f(a, b)
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])