diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py
index e50e9898d..32c147916 100644
--- a/benchmarks/mosaic/matmul_bench.py
+++ b/benchmarks/mosaic/matmul_bench.py
@@ -40,7 +40,7 @@ def matmul_benchmark(*args):
       @functools.wraps(get_runtimes)
       def wrapper(state, test_case=test_case):
         m, n, k = test_case["m"], test_case["n"], test_case["k"]
-        runtime, ref_runtime = get_runtimes(state, **test_case)
+        runtime, ref_runtime = get_runtimes(**test_case)
         state.counters["TFlops"] = (
             float(2 * k * m * n) / (runtime / 1e3) / 1e12
         )
@@ -50,8 +50,6 @@ def matmul_benchmark(*args):
         state.counters["speedup"] = ref_runtime / runtime
         state.set_iteration_time(runtime / 1e3)
 
-    return wrapper
-
   return decorator
 
 
@@ -61,7 +59,7 @@ def matmul_benchmark(*args):
     dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
     dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
 )
-def bf16_i8_matmul(self, m, k, n, stages, tile_m):
+def bf16_i8_matmul(m, k, n, stages, tile_m):
   # RHS.element_size==1b so k_tile=128
   if stages * 128 > k:
     raise ValueError(f"Too many stages {(stages, k)=}.")
@@ -77,6 +75,27 @@ def bf16_i8_matmul(self, m, k, n, stages, tile_m):
       rhs_dtype=jnp.int8,
   )
 
+@matmul_benchmark(
+    dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=256),
+    dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=128),
+    dict(m=1024, n=1024, k=1024, stages=4, tile_m=64, tile_n=128),
+)
+def f32_matmul(m, n, k, stages, tile_m, tile_n):
+  if stages * 32 > k:
+    raise ValueError(f"Too many stages {(stages, k)=}.")
+
+  return matmul.verify(
+      m=m,
+      k=k,
+      n=n,
+      stages=stages,
+      tile_m=tile_m,
+      tile_n=tile_n,
+      rhs_transpose=True,
+      lhs_dtype=jnp.float32,
+      rhs_dtype=jnp.float32,
+  )
+
 
 def main(_):
   device = jtu.device_under_test()