mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
math_benchmark: add dot op
PiperOrigin-RevId: 515408666
This commit is contained in:
parent
f1f4840a0d
commit
845d68b39e
@ -72,6 +72,61 @@ def jax_unary(state, **kwargs):
|
||||
input0.size * state.iterations, Counter.kIsRate
|
||||
)
|
||||
|
||||
@math_benchmark(
|
||||
[
|
||||
{
|
||||
'name': f'{op.__name__}_{mkn[0]}x{mkn[1]}x{mkn[2]}_{dtype}',
|
||||
'mkn': mkn,
|
||||
'dtype': dtype,
|
||||
'op': op,
|
||||
}
|
||||
for op in [
|
||||
jnp.dot,
|
||||
]
|
||||
for mkn in [[2**i, 2**i, 2**i] for i in range(4, 11, 1)] +
|
||||
[
|
||||
[1, 2, 256],
|
||||
[1, 8, 256],
|
||||
[1, 18, 300],
|
||||
[1, 37, 256],
|
||||
[1, 91, 256],
|
||||
[1, 111, 256],
|
||||
[1, 192, 192],
|
||||
[1, 226, 256],
|
||||
[1, 256, 192],
|
||||
[1, 256, 256],
|
||||
[1, 512, 512],
|
||||
[1, 300, 18],
|
||||
[21, 24, 1],
|
||||
[21, 120, 1],
|
||||
[10, 10, 10],
|
||||
[100, 100, 100],
|
||||
[18, 1, 300],
|
||||
[18, 300, 1],
|
||||
[300, 1, 18],
|
||||
[300, 18, 1],
|
||||
]
|
||||
for dtype in ['float32']
|
||||
]
|
||||
)
|
||||
def jax_binary_op(state, **kwargs):
|
||||
mkn = kwargs['mkn']
|
||||
m = mkn[0]
|
||||
k = mkn[1]
|
||||
n = mkn[2]
|
||||
dtype = kwargs['dtype']
|
||||
op = kwargs['op']
|
||||
a = np.random.random([m, k]).astype(dtype)
|
||||
b = np.random.random([k, n]).astype(dtype)
|
||||
f = op
|
||||
f_jitted = jax.jit(f)
|
||||
f_jitted(a, b).block_until_ready()
|
||||
while state:
|
||||
f_jitted(a, b).block_until_ready()
|
||||
state.counters['items_per_second'] = Counter(
|
||||
state.iterations, Counter.kIsRate
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
benchmark.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user