mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix the mesh_axes of benchmarks to be Pspecs
PiperOrigin-RevId: 437932954
This commit is contained in:
parent
d69c7b3c21
commit
a68b0f3a0a
@ -19,19 +19,20 @@ import google_benchmark
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import prod
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
from jax.experimental import global_device_array as gda
|
||||
import numpy as np
|
||||
|
||||
mesh_shapes_axes = [
|
||||
((256, 8), ["x", "y"]),
|
||||
((256, 8), [None]),
|
||||
((256, 8), ["x"]),
|
||||
((256, 8), ["y"]),
|
||||
((256, 8), [("x", "y")]),
|
||||
((128, 8), ["x", "y"]),
|
||||
((4, 2), ["x", "y"]),
|
||||
((16, 4), ["x", "y"]),
|
||||
((16, 4), [("x", "y")]),
|
||||
((256, 8), P("x", "y")),
|
||||
((256, 8), P(None)),
|
||||
((256, 8), P("x")),
|
||||
((256, 8), P("y")),
|
||||
((256, 8), P(("x", "y"))),
|
||||
((128, 8), P("x", "y")),
|
||||
((4, 2), P("x", "y")),
|
||||
((16, 4), P("x", "y")),
|
||||
((16, 4), P(("x", "y"))),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user