mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic GPU] Add f32 benchmarks for matmul.
PiperOrigin-RevId: 640826101
This commit is contained in:
parent
55d0f5ef8f
commit
24e4bf2265
@ -40,7 +40,7 @@ def matmul_benchmark(*args):
|
||||
@functools.wraps(get_runtimes)
|
||||
def wrapper(state, test_case=test_case):
|
||||
m, n, k = test_case["m"], test_case["n"], test_case["k"]
|
||||
runtime, ref_runtime = get_runtimes(state, **test_case)
|
||||
runtime, ref_runtime = get_runtimes(**test_case)
|
||||
state.counters["TFlops"] = (
|
||||
float(2 * k * m * n) / (runtime / 1e3) / 1e12
|
||||
)
|
||||
@ -50,8 +50,6 @@ def matmul_benchmark(*args):
|
||||
state.counters["speedup"] = ref_runtime / runtime
|
||||
state.set_iteration_time(runtime / 1e3)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -61,7 +59,7 @@ def matmul_benchmark(*args):
|
||||
dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
|
||||
dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
|
||||
)
|
||||
def bf16_i8_matmul(self, m, k, n, stages, tile_m):
|
||||
def bf16_i8_matmul(m, k, n, stages, tile_m):
|
||||
# RHS.element_size==1b so k_tile=128
|
||||
if stages * 128 > k:
|
||||
raise ValueError(f"Too many stages {(stages, k)=}.")
|
||||
@ -77,6 +75,27 @@ def bf16_i8_matmul(self, m, k, n, stages, tile_m):
|
||||
rhs_dtype=jnp.int8,
|
||||
)
|
||||
|
||||
@matmul_benchmark(
|
||||
dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=256),
|
||||
dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=128),
|
||||
dict(m=1024, n=1024, k=1024, stages=4, tile_m=64, tile_n=128),
|
||||
)
|
||||
def f32_matmul(m, n, k, stages, tile_m, tile_n):
|
||||
if stages * 32 > k:
|
||||
raise ValueError(f"Too many stages {(stages, k)=}.")
|
||||
|
||||
return matmul.verify(
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
stages=stages,
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
rhs_transpose=True,
|
||||
lhs_dtype=jnp.float32,
|
||||
rhs_dtype=jnp.float32,
|
||||
)
|
||||
|
||||
|
||||
def main(_):
|
||||
device = jtu.device_under_test()
|
||||
|
Loading…
x
Reference in New Issue
Block a user