mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 481711690
This commit is contained in:
parent
504b3c1b25
commit
fd2f590b3b
@ -696,6 +696,17 @@ def pjit_simple_4000_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
|
||||
|
||||
@google_benchmark.register
|
||||
@jax_config.jax_array(True)
|
||||
def simple_dispatch_array_pjit(state):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = pjit_lib.pjit(operator.add)
|
||||
f(a, b)
|
||||
|
||||
while state:
|
||||
f(a, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
google_benchmark.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user