mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a benchmark with many arguments.
PiperOrigin-RevId: 393216026
This commit is contained in:
parent
26c9671413
commit
6cb8737c1a
@ -266,6 +266,33 @@ def pmap_simple_8_devices(state):
|
||||
d.block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@required_devices(8)
|
||||
def pmap_simple_dispatch_8_devices_100_args(state):
|
||||
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
|
||||
args = []
|
||||
for i in range(100):
|
||||
args.append(jnp.array(list(range(i, i+8))))
|
||||
|
||||
args = f(*args)
|
||||
|
||||
while state:
|
||||
args = f(*args)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@required_devices(8)
|
||||
def pmap_simple_8_devices_100_args(state):
|
||||
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
|
||||
args = []
|
||||
for i in range(100):
|
||||
args.append(jnp.array(list(range(i, i+8))))
|
||||
|
||||
while state:
|
||||
out = f(*args)
|
||||
jax.tree_map(lambda x: x.block_until_ready(), out)
|
||||
|
||||
|
||||
def _run_sda_index_bench(state, num_devices):
|
||||
x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
|
||||
jax.device_get(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user