[Mosaic GPU] Add f32 benchmarks for matmul.

PiperOrigin-RevId: 640826101
This commit is contained in:
Christos Perivolaropoulos 2024-06-06 02:30:41 -07:00 committed by jax authors
parent 55d0f5ef8f
commit 24e4bf2265

View File

@ -40,7 +40,7 @@ def matmul_benchmark(*args):
@functools.wraps(get_runtimes)
def wrapper(state, test_case=test_case):
m, n, k = test_case["m"], test_case["n"], test_case["k"]
runtime, ref_runtime = get_runtimes(state, **test_case)
runtime, ref_runtime = get_runtimes(**test_case)
state.counters["TFlops"] = (
float(2 * k * m * n) / (runtime / 1e3) / 1e12
)
@ -50,8 +50,6 @@ def matmul_benchmark(*args):
state.counters["speedup"] = ref_runtime / runtime
state.set_iteration_time(runtime / 1e3)
return wrapper
return decorator
@ -61,7 +59,7 @@ def matmul_benchmark(*args):
dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
)
def bf16_i8_matmul(self, m, k, n, stages, tile_m):
def bf16_i8_matmul(m, k, n, stages, tile_m):
# RHS.element_size==1b so k_tile=128
if stages * 128 > k:
raise ValueError(f"Too many stages {(stages, k)=}.")
@ -77,6 +75,27 @@ def bf16_i8_matmul(self, m, k, n, stages, tile_m):
rhs_dtype=jnp.int8,
)
@matmul_benchmark(
dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=256),
dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=128),
dict(m=1024, n=1024, k=1024, stages=4, tile_m=64, tile_n=128),
)
def f32_matmul(m, n, k, stages, tile_m, tile_n):
if stages * 32 > k:
raise ValueError(f"Too many stages {(stages, k)=}.")
return matmul.verify(
m=m,
k=k,
n=n,
stages=stages,
tile_m=tile_m,
tile_n=tile_n,
rhs_transpose=True,
lhs_dtype=jnp.float32,
rhs_dtype=jnp.float32,
)
def main(_):
device = jtu.device_under_test()