mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic GPU] Add a simple benchmark.
PiperOrigin-RevId: 639023867
This commit is contained in:
parent
d2a39bc61b
commit
8eaea2b13d
56
benchmarks/mosaic/BUILD
Normal file
56
benchmarks/mosaic/BUILD
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Build rules for the Mosaic GPU microbenchmarks.

load(
    "//jaxlib:jax.bzl",
    "jax_generate_backend_suites",
    "jax_test",
    "py_deps",
)

licenses(["notice"])

package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

jax_generate_backend_suites()

# Backends on which the benchmark target is not generated.
DISABLED_BACKENDS = [
    "cpu",
    "tpu",
]

# GPU test configurations on which the benchmark target is not generated.
# NOTE(review): presumably this leaves only the newest GPU configuration(s)
# enabled -- confirm against the configuration list in //jaxlib:jax.bzl.
DISABLED_CONFIGS = [
    "gpu",
    "gpu_a100",
    "gpu_p100",
    "gpu_p100_x32",
    "gpu_x32",
    "gpu_pjrt_c_api",
]

jax_test(
    name = "matmul_bench",
    srcs = ["matmul_bench.py"],
    disable_backends = DISABLED_BACKENDS,
    disable_configs = DISABLED_CONFIGS,
    tags = ["notap"],
    # NOTE(review): the //third_party/py/... labels look like monorepo-internal
    # paths -- confirm that they resolve in this repository.
    deps = [
        "//third_party/py/google_benchmark",
        "//third_party/py/jax:mosaic_gpu",
        "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
91
benchmarks/mosaic/matmul_bench.py
Normal file
91
benchmarks/mosaic/matmul_bench.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Microbenchmarks for mosaic gpu matrix multiplication."""
|
||||
|
||||
import functools
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
import google_benchmark as benchmark
|
||||
from jax._src import config
|
||||
from jax.experimental.mosaic.gpu.examples import matmul
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Turn traceback filtering off so that failures show complete stack traces.
config.update("jax_traceback_filtering", "off")
# Let absl parse the JAX configuration flags from the command line.
config.parse_flags_with_absl()
|
||||
|
||||
def _params_name(params):
|
||||
return ",".join(f"{k}={v}" for k, v in params.items())
|
||||
|
||||
def matmul_benchmark(*args):
  """Builds a decorator that registers one benchmark per test case.

  Args:
    *args: test-case dicts. Each must contain "m", "n" and "k" entries and is
      forwarded in full as keyword arguments to the decorated function.

  Returns:
    A decorator. It expects a function ``get_runtimes(state, **test_case)``
    that returns a ``(runtime, ref_runtime)`` pair, registers a benchmark for
    every test case, and returns the wrapper created for the last one.

  Raises:
    ValueError: if no test case is given.
  """
  if not args:
    # Without any test case the loop below never defines `wrapper`, and the
    # decorator would fail with an opaque UnboundLocalError instead.
    raise ValueError("matmul_benchmark requires at least one test case.")

  def decorator(get_runtimes):
    for test_case in args:

      @benchmark.register(name=f"{get_runtimes.__name__}_{_params_name(test_case)}")
      @benchmark.option.unit(benchmark.kMillisecond)
      @benchmark.option.use_manual_time()
      @benchmark.option.iterations(1)
      @functools.wraps(get_runtimes)
      def wrapper(state, test_case=test_case):
        # `test_case` is bound as a default argument so that every registered
        # wrapper keeps its own case rather than the loop's final value.
        m, n, k = test_case["m"], test_case["n"], test_case["k"]
        runtime, ref_runtime = get_runtimes(state, **test_case)
        # One multiply and one add per (m, n, k) triple.
        flops = float(2 * k * m * n)
        # Runtimes are divided by 1e3 before use -- presumably they are
        # reported in milliseconds (TODO confirm against the profiler).
        state.counters["TFlops"] = flops / (runtime / 1e3) / 1e12
        state.counters["jax_TFlops"] = flops / (ref_runtime / 1e3) / 1e12
        state.counters["speedup"] = ref_runtime / runtime
        state.set_iteration_time(runtime / 1e3)

    return wrapper

  return decorator
|
||||
|
||||
|
||||
@matmul_benchmark(
    dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128),
    dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128),
    dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
    dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
)
def bf16_i8_matmul(self, m, k, n, stages, tile_m):
  """Runs the bfloat16 x int8 matmul example for one problem size.

  Args:
    self: the benchmark state, passed positionally by the registering
      wrapper; it is not used here.
    m: M dimension of the problem.
    k: K dimension of the problem.
    n: N dimension of the problem.
    stages: stage count forwarded to ``matmul.verify``.
    tile_m: M tile size forwarded to ``matmul.verify``.

  Returns:
    Whatever ``matmul.verify`` returns; the registering wrapper unpacks it as
    a ``(runtime, ref_runtime)`` pair.

  Raises:
    ValueError: if ``stages * 128`` exceeds ``k``.
  """
  # RHS.element_size==1b so k_tile=128
  if stages * 128 > k:
    raise ValueError(f"Too many stages {(stages, k)=}.")

  return matmul.verify(
      m,
      k,
      n,
      stages,
      tile_m=tile_m,
      rhs_transpose=False,
      lhs_dtype=jnp.bfloat16,
      rhs_dtype=jnp.int8,
  )
|
||||
|
||||
|
||||
def main(_):
  """Runs every registered benchmark after checking the device under test.

  Args:
    _: unused positional argument supplied by ``app.run``.

  Raises:
    ValueError: if the device under test is not "gpu".
  """
  device = jtu.device_under_test()
  if device != "gpu":
    raise ValueError(f"Mosaic only work with gpu (got {device})")

  benchmark.run_benchmarks()
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Hand argv to the benchmark library first and keep whatever it returns as
  # the argv that absl sees -- presumably with the benchmark flags removed.
  sys.argv = benchmark.initialize(sys.argv)
  app.run(main)
|
@ -16,6 +16,7 @@
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax import random
|
||||
@ -508,12 +509,16 @@ def verify(
|
||||
jax.lax.reduce_precision(v, exponent_bits, mantissa_bits)
|
||||
for v in (x, y)
|
||||
)
|
||||
ref = jax.lax.dot_general(
|
||||
x, y, dimension_numbers,
|
||||
|
||||
ref_f = functools.partial(
|
||||
jax.lax.dot_general,
|
||||
dimension_numbers=dimension_numbers,
|
||||
preferred_element_type=jnp.float32,
|
||||
)
|
||||
|
||||
ref, ref_runtime = profiler.measure(ref_f, x, y)
|
||||
np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3)
|
||||
return runtime
|
||||
return runtime, ref_runtime
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user