math_benchmark: add dot op

PiperOrigin-RevId: 515408666
This commit is contained in:
Emilio Cota 2023-03-09 12:24:09 -08:00 committed by jax authors
parent f1f4840a0d
commit 845d68b39e

View File

@ -72,6 +72,61 @@ def jax_unary(state, **kwargs):
input0.size * state.iterations, Counter.kIsRate
)
@math_benchmark(
[
{
'name': f'{op.__name__}_{mkn[0]}x{mkn[1]}x{mkn[2]}_{dtype}',
'mkn': mkn,
'dtype': dtype,
'op': op,
}
for op in [
jnp.dot,
]
for mkn in [[2**i, 2**i, 2**i] for i in range(4, 11, 1)] +
[
[1, 2, 256],
[1, 8, 256],
[1, 18, 300],
[1, 37, 256],
[1, 91, 256],
[1, 111, 256],
[1, 192, 192],
[1, 226, 256],
[1, 256, 192],
[1, 256, 256],
[1, 512, 512],
[1, 300, 18],
[21, 24, 1],
[21, 120, 1],
[10, 10, 10],
[100, 100, 100],
[18, 1, 300],
[18, 300, 1],
[300, 1, 18],
[300, 18, 1],
]
for dtype in ['float32']
]
)
def jax_binary_op(state, **kwargs):
mkn = kwargs['mkn']
m = mkn[0]
k = mkn[1]
n = mkn[2]
dtype = kwargs['dtype']
op = kwargs['op']
a = np.random.random([m, k]).astype(dtype)
b = np.random.random([k, n]).astype(dtype)
f = op
f_jitted = jax.jit(f)
f_jitted(a, b).block_until_ready()
while state:
f_jitted(a, b).block_until_ready()
state.counters['items_per_second'] = Counter(
state.iterations, Counter.kIsRate
)
if __name__ == '__main__':
benchmark.main()