mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Rollback]
Allow uncommitted single device PyArray in C++ pjit path. PiperOrigin-RevId: 482084898
This commit is contained in:
parent
807269990e
commit
3572bb2db0
@ -699,17 +699,6 @@ def pjit_simple_4000_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
|
||||
|
||||
@google_benchmark.register
|
||||
@jax_config.jax_array(True)
|
||||
def simple_dispatch_array_pjit(state):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = pjit_lib.pjit(operator.add)
|
||||
f(a, b)
|
||||
|
||||
while state:
|
||||
f(a, b)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user