diff --git a/benchmarks/math_benchmark.py b/benchmarks/math_benchmark.py index 2dc95263f..4a0a0b7a6 100644 --- a/benchmarks/math_benchmark.py +++ b/benchmarks/math_benchmark.py @@ -72,6 +72,61 @@ def jax_unary(state, **kwargs): input0.size * state.iterations, Counter.kIsRate ) +@math_benchmark( + [ + { + 'name': f'{op.__name__}_{mkn[0]}x{mkn[1]}x{mkn[2]}_{dtype}', + 'mkn': mkn, + 'dtype': dtype, + 'op': op, + } + for op in [ + jnp.dot, + ] + for mkn in [[2**i, 2**i, 2**i] for i in range(4, 11, 1)] + + [ + [1, 2, 256], + [1, 8, 256], + [1, 18, 300], + [1, 37, 256], + [1, 91, 256], + [1, 111, 256], + [1, 192, 192], + [1, 226, 256], + [1, 256, 192], + [1, 256, 256], + [1, 512, 512], + [1, 300, 18], + [21, 24, 1], + [21, 120, 1], + [10, 10, 10], + [100, 100, 100], + [18, 1, 300], + [18, 300, 1], + [300, 1, 18], + [300, 18, 1], + ] + for dtype in ['float32'] + ] +) +def jax_binary_op(state, **kwargs): + mkn = kwargs['mkn'] + m = mkn[0] + k = mkn[1] + n = mkn[2] + dtype = kwargs['dtype'] + op = kwargs['op'] + a = np.random.random([m, k]).astype(dtype) + b = np.random.random([k, n]).astype(dtype) + f = op + f_jitted = jax.jit(f) + f_jitted(a, b).block_until_ready() + while state: + f_jitted(a, b).block_until_ready() + state.counters['items_per_second'] = Counter( + state.iterations, Counter.kIsRate + ) + if __name__ == '__main__': benchmark.main()