[Mosaic GPU] Add a simple benchmark.

PiperOrigin-RevId: 639023867
Christos Perivolaropoulos 2024-05-31 07:07:38 -07:00 committed by jax authors
parent d2a39bc61b
commit 8eaea2b13d
3 changed files with 155 additions and 3 deletions

benchmarks/mosaic/BUILD

@@ -0,0 +1,56 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"jax_generate_backend_suites",
"jax_test",
"py_deps",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
jax_generate_backend_suites()
DISABLED_BACKENDS = [
"cpu",
"tpu",
]
DISABLED_CONFIGS = [
"gpu",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_x32",
"gpu_pjrt_c_api",
]
jax_test(
name = "matmul_bench",
srcs = ["matmul_bench.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
tags = ["notap"],
deps = [
"//third_party/py/google_benchmark",
"//third_party/py/jax:mosaic_gpu",
"//third_party/py/jax/experimental/mosaic/gpu/examples:matmul",
] + py_deps("absl/testing") + py_deps("numpy"),
)

benchmarks/mosaic/matmul_bench.py

@@ -0,0 +1,91 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for mosaic gpu matrix mutliplication."""
import functools
import sys
from absl import app
import google_benchmark as benchmark
from jax._src import config
from jax.experimental.mosaic.gpu.examples import matmul
from jax._src import test_util as jtu
import jax.numpy as jnp
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
def _params_name(params):
return ",".join(f"{k}={v}" for k, v in params.items())
def matmul_benchmark(*args):
def decorator(get_runtimes):
for test_case in args:
@benchmark.register(name=f"{get_runtimes.__name__}_{_params_name(test_case)}")
@benchmark.option.unit(benchmark.kMillisecond)
@benchmark.option.use_manual_time()
@benchmark.option.iterations(1)
@functools.wraps(get_runtimes)
def wrapper(state, test_case=test_case):
m, n, k = test_case["m"], test_case["n"], test_case["k"]
runtime, ref_runtime = get_runtimes(state, **test_case)
state.counters["TFlops"] = (
float(2 * k * m * n) / (runtime / 1e3) / 1e12
)
state.counters["jax_TFlops"] = (
float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
)
state.counters["speedup"] = ref_runtime / runtime
state.set_iteration_time(runtime / 1e3)
return wrapper
return decorator
@matmul_benchmark(
dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128),
dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128),
dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
)
def bf16_i8_matmul(self, m, k, n, stages, tile_m):
# RHS.element_size==1b so k_tile=128
if stages * 128 > k:
raise ValueError(f"Too many stages {(stages, k)=}.")
return matmul.verify(
m,
k,
n,
stages,
tile_m=tile_m,
rhs_transpose=False,
lhs_dtype=jnp.bfloat16,
rhs_dtype=jnp.int8,
)
def main(_):
device = jtu.device_under_test()
if device != "gpu":
raise ValueError(f"Mosaic only work with gpu (got {device})")
benchmark.run_benchmarks()
if __name__ == "__main__":
sys.argv = benchmark.initialize(sys.argv)
app.run(main)
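
For reference (not part of the diff above), the registered counters reduce to simple arithmetic: 2*m*n*k floating-point operations divided by the runtime in seconds. A standalone sketch for the first test case, using made-up runtimes purely for illustration:

# Standalone sketch of the benchmark counters for the first test case.
# The runtimes below are hypothetical values, not measurements.
m, n, k = 55 * 128, 95 * 128, 48 * 128
runtime, ref_runtime = 5.0, 6.5  # milliseconds, illustrative only

flops = 2 * k * m * n                                   # one multiply-add counts as 2 FLOPs
tflops = float(flops) / (runtime / 1e3) / 1e12          # ms -> s, then FLOP/s -> TFLOP/s
jax_tflops = float(flops) / (ref_runtime / 1e3) / 1e12
speedup = ref_runtime / runtime
print(f"{tflops=:.1f} {jax_tflops=:.1f} {speedup=:.2f}")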

jax/experimental/mosaic/gpu/examples/matmul.py

@@ -16,6 +16,7 @@
 import dataclasses
 import enum
+import functools
 import jax
 from jax import random
@@ -508,12 +509,16 @@ def verify(
         jax.lax.reduce_precision(v, exponent_bits, mantissa_bits)
         for v in (x, y)
     )
-  ref = jax.lax.dot_general(
-      x, y, dimension_numbers,
+  ref_f = functools.partial(
+      jax.lax.dot_general,
+      dimension_numbers=dimension_numbers,
       preferred_element_type=jnp.float32,
   )
+  ref, ref_runtime = profiler.measure(ref_f, x, y)
   np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3)
-  return runtime
+  return runtime, ref_runtime
if __name__ == "__main__":
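
The change above times the jax.lax.dot_general reference with the same profiler used for the Mosaic kernel, so verify now returns both runtimes. A minimal sketch of the pattern (not part of the diff), with a wall-clock stand-in for profiler.measure, which this change assumes returns the result together with its runtime in milliseconds:

# Sketch of the reference-timing pattern; `measure` here is a hypothetical
# wall-clock stand-in for the Mosaic GPU profiler.measure helper.
import functools
import time

import jax
import jax.numpy as jnp

def measure(f, *args):
  f(*args).block_until_ready()  # warm up / compile outside the timed region
  start = time.perf_counter()
  out = f(*args).block_until_ready()
  return out, (time.perf_counter() - start) * 1e3  # (result, runtime in ms)

x = jnp.ones((64, 256), jnp.bfloat16)
y = jnp.ones((256, 128), jnp.bfloat16)
ref_f = functools.partial(
    jax.lax.dot_general,
    dimension_numbers=(((1,), (0,)), ((), ())),  # plain [m, k] @ [k, n] contraction
    preferred_element_type=jnp.float32,
)
ref, ref_runtime = measure(ref_f, x, y)
print(f"reference matmul: {ref_runtime:.3f} ms, output shape {ref.shape}")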