[Mosaic GPU] Add CUPTI profiler alongside events-based implementation

Andrey Portnoy 2024-11-19 22:48:35 -05:00
parent 12b45b3235
commit cc22334c21
8 changed files with 301 additions and 20 deletions

View File

@ -576,7 +576,7 @@ def benchmark_and_verify(
head_dim=head_dim,
**kwargs,
)
out, runtime = profiler.measure(f, q[0], k[0], v[0])
out, runtime = profiler.measure(f)(q[0], k[0], v[0])
out = out[None]
@jax.jit

View File

@ -360,7 +360,7 @@ def verify(
wgmma_impl=WGMMADefaultImpl,
profiler_spec=prof_spec,
)
z, runtime = profiler.measure(f, x, y)
z, runtime = profiler.measure(f)(x, y)
if rhs_transpose:
dimension_numbers = ((1,), (1,)), ((), ())
@ -382,7 +382,7 @@ def verify(
preferred_element_type=out_dtype,
).astype(out_dtype)
ref, ref_runtime = profiler.measure(ref_f, x, y)
ref, ref_runtime = profiler.measure(ref_f)(x, y)
np.testing.assert_allclose(
z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
)
@ -426,7 +426,7 @@ if __name__ == "__main__":
f = build_kernel(
m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
)
_, runtime = profiler.measure(f, x, y)
_, runtime = profiler.measure(f)(x, y)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise

View File

@ -69,20 +69,9 @@ def _event_elapsed(start_event, end_event):
)(start_event, end_event)
def measure(
def _measure_events(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
"""Measures the time it takes to execute the function on the GPU.
Args:
f: The function to measure. It must accept at least one argument and return
at least one output to be measurable.
*args: The arguments to pass to ``f``.
**kwargs: The keyword arguments to pass to ``f``.
Returns:
The return value of ``f`` and the elapsed time in milliseconds.
"""
if not has_registrations:
raise RuntimeError(
"This function requires jaxlib >=0.4.36 with CUDA support."
@ -109,6 +98,91 @@ def measure(
return outs, float(elapsed)
def _measure_cupti(f, aggregate):
  def wrapper(*args, **kwargs):
    mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
    try:
      results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
    finally:
      timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()

    if not timings:
      return results, None
    elif aggregate:
      return results, sum(item[1] for item in timings)
    else:
      return results, timings

  return wrapper


def measure(
    f: Callable, *, mode: str = "cupti", aggregate: bool = True
) -> Callable:
  """Sets up a function ``f`` for profiling on GPU.

  ``measure`` is a higher-order function that augments the argument ``f`` to
  return GPU runtime in milliseconds, in addition to its regular outputs.

  Args:
    f: The function to measure. It must accept at least one argument and return
      at least one output to be measurable.
    mode: The mode of operation. Possible values are:

      - "cupti", for CUPTI-based profiling.
      - "events", for CUDA events-based profiling.

      The two modes use different measurement methodologies and should not be
      treated as interchangeable backends. See the Notes section below for
      details.
    aggregate: Whether to report an aggregate runtime. When ``False`` (only
      supported by ``mode="cupti"``), the per-kernel timings are returned as a
      list of tuples ``(<kernel name>, <runtime in ms>)``.

  Returns:
    A new function ``g`` that returns the measured GPU runtime as its last
    additional output. In all other respects ``g`` accepts the same inputs and
    returns the same outputs as ``f``.

  Notes:
    `CUPTI (CUDA Profiling Tools Interface)
    <https://docs.nvidia.com/cupti/index.html>`_ is a high-accuracy,
    high-precision profiling and tracing API, used in particular by Nsight
    Systems and Nsight Compute. When using ``measure`` with ``mode="cupti"``,
    device (GPU) execution runtimes are recorded for each kernel launched
    during the execution of the function. In that mode, setting
    ``aggregate=True`` will sum the individual kernel runtimes to arrive at an
    aggregate measurement. The "gaps" between the kernels when the device is
    idle are not included in the aggregate.

    The CUPTI API only allows a single "subscriber". This means that the
    CUPTI-based profiler will fail when the program is run under tools that
    make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems, or
    Nsight Compute.

    ``mode="events"`` uses a different approach: a CUDA event is recorded
    before and after the function ``f`` is executed. The reported runtime is
    the time elapsed between the two events. In particular, the measurement
    includes:

    - any potential "gaps" between the kernels when the device is idle
    - any potential "gaps" between the "before" event and the start of the
      first kernel, or between the end of the last kernel and the "after" event

    To minimize the second effect, the events-based implementation may
    internally execute ``f`` more than once to "warm up" and exclude
    compilation time from the measurement.
  """
  match mode:
    case "cupti":
      return _measure_cupti(f, aggregate)
    case "events":
      if not aggregate:
        raise ValueError(f"{aggregate=} is not supported with {mode=}")

      def measure_events_wrapper(*args, **kwargs):
        return _measure_events(f, *args, **kwargs)

      return measure_events_wrapper
    case _:
      raise ValueError(f"Unrecognized profiler mode {mode}")

class ProfilerSpec:
ENTER = 0
EXIT = 1 << 31
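
For reference, a minimal usage sketch of the reworked ``measure`` API, matching the call sites updated above (the inputs are arbitrary and the example assumes a jaxlib built with CUDA support):

import jax.numpy as jnp
from jax.experimental.mosaic.gpu import profiler

x = jnp.arange(1024 * 1024)
f = lambda a, b: a + b

# CUPTI mode (the default): per-kernel device runtimes are summed into a
# single aggregate measurement.
out, runtime_ms = profiler.measure(f)(x, x)

# aggregate=False returns the per-kernel breakdown instead, as a list of
# (kernel name, runtime in ms) tuples.
out, timings = profiler.measure(f, mode="cupti", aggregate=False)(x, x)

# Events mode: the elapsed time between CUDA events recorded around the call.
out, runtime_ms = profiler.measure(f, mode="events")(x, x)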

View File

@ -274,7 +274,7 @@ def main(unused_argv):
for block_kv in (256, 128, 64):
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
try:
out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
if seq_len < 32768:
out_ref = attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)

View File

@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <cstddef>
#include <cstdint>
#include <new>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
#include "nanobind/nanobind.h"
#include "nanobind/stl/tuple.h"
#include "nanobind/stl/vector.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
@ -118,6 +124,75 @@ XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
return kEventElapsed->Call(call_frame);
}
#define THROW(...) \
do { \
throw std::runtime_error( \
absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__)); \
} while (0)
#define THROW_IF(expr, ...) \
do { \
if (expr) THROW(__VA_ARGS__); \
} while (0)
#define THROW_IF_CUPTI_ERROR(expr, ...) \
do { \
CUptiResult _result = (expr); \
if (_result != CUPTI_SUCCESS) { \
const char* s; \
cuptiGetErrorMessage(_result, &s); \
THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__); \
} \
} while (0)
// CUPTI can only have one subscriber per process, so it's ok to make the
// profiler state global.
struct {
CUpti_SubscriberHandle subscriber;
std::vector<std::tuple<const char* /*kernel_name*/, double /*ms*/>> timings;
} profiler_state;
void callback_request(uint8_t** buffer, size_t* size, size_t* maxNumRecords) {
// 10 MiB buffer size is generous but somewhat arbitrary, it's at the upper
// bound of what's recommended in CUPTI documentation:
// https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
const int buffer_size = 10 * (1 << 20);
// 8 byte alignment is specified in the official CUPTI code samples, see
// extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
// installation.
*buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
*size = buffer_size;
*maxNumRecords = 0;
}
void callback_complete(CUcontext context, uint32_t streamId,
uint8_t* buffer, size_t size, size_t validSize) {
// take ownership of the buffer once CUPTI is done using it
absl::Cleanup cleanup = [buffer]() {
operator delete[](buffer, std::align_val_t(8));
};
CUpti_Activity* record = nullptr;
while (true) {
CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record);
if (status == CUPTI_SUCCESS) {
if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
// TODO(andportnoy) handle multi-GPU
CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;
// Convert integer nanoseconds to floating point milliseconds to match
// the interface of the events-based profiler.
double duration_ms = (kernel->end - kernel->start) / 1e6;
profiler_state.timings.push_back(
std::make_tuple(kernel->name, duration_ms));
}
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
// no more records available
break;
} else {
THROW_IF_CUPTI_ERROR(status);
}
}
}
NB_MODULE(_mosaic_gpu_ext, m) {
m.def("registrations", []() {
return nb::make_tuple(
@ -139,6 +214,35 @@ NB_MODULE(_mosaic_gpu_ext, m) {
}
}
});
m.def("_cupti_init", []() {
profiler_state.timings.clear();
// Ok to pass nullptr for the callback here because we don't register any
// callbacks through cuptiEnableCallback.
auto subscribe_result = cuptiSubscribe(
&profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
THROW(
"Attempted to subscribe to CUPTI while another subscriber, such as "
"Nsight Systems or Nsight Compute, is active. CUPTI backend of the "
"Mosaic GPU profiler cannot be used in that mode since CUPTI does "
"not support multiple subscribers.");
}
THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
THROW_IF_CUPTI_ERROR(
cuptiActivityRegisterCallbacks(callback_request, callback_complete),
"failed to register CUPTI activity callbacks");
THROW_IF_CUPTI_ERROR(
cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
"failed to enable tracking of kernel activity by CUPTI");
});
m.def("_cupti_get_timings", []() {
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
"failed to unsubscribe from CUPTI");
THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
"failed to flush CUPTI activity buffers");
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
return profiler_state.timings;
});
}
} // namespace
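
For context, a sketch of the lifecycle that ``_measure_cupti`` drives through these bindings: ``_cupti_init`` subscribes to CUPTI, registers the activity callbacks, and enables kernel activity tracking; the workload then runs; and ``_cupti_get_timings`` unsubscribes, flushes the activity buffers, finalizes CUPTI, and returns the collected records. The jitted workload below is an arbitrary stand-in, not part of this change:

import jax
import jax.numpy as jnp
import jax._src.lib.mosaic_gpu as mosaic_gpu_lib  # requires a jaxlib built with CUDA

ext = mosaic_gpu_lib._mosaic_gpu_ext
ext._cupti_init()  # subscribe, register activity callbacks, enable kernel tracking
try:
  x = jnp.arange(1024 * 1024)
  out = jax.block_until_ready(jax.jit(lambda a: 2 * a)(x))
finally:
  # Unsubscribe, flush the activity buffers, and finalize CUPTI; returns a
  # list of (kernel name, runtime in ms) tuples recorded since _cupti_init().
  timings = ext._cupti_get_timings()

for name, ms in timings:
  print(f"{name}: {ms:.3f} ms")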

View File

@ -88,3 +88,17 @@ jax_multiplatform_test(
"//jax/experimental/mosaic/gpu/examples:flash_attention",
] + py_deps("absl/testing"),
)
jax_multiplatform_test(
name = "profiler_cupti_test",
srcs = ["profiler_cupti_test.py"],
enable_backends = [],
enable_configs = ["gpu_h100"],
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing"),
tags = [
"noasan", # CUPTI leaks memory
"nomsan",
],
)

View File

@ -1691,9 +1691,10 @@ class FragmentedArrayTest(TestCase):
class ProfilerTest(TestCase):
def test_measure(self):
def test_measure_events_explicit(self):
x = jnp.arange(1024 * 1024)
profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test
_, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x)
self.assertIsInstance(runtime_ms, float)
def test_profile(self):
def kernel(ctx, src, dst, _):

View File

@ -0,0 +1,88 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU CUPTI-based profiler."""
from absl.testing import absltest, parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp

try:
  import jax._src.lib.mosaic_gpu  # noqa: F401
  HAS_MOSAIC_GPU = True
except ImportError:
  HAS_MOSAIC_GPU = False
else:
  from jax.experimental.mosaic.gpu import profiler

# ruff: noqa: F405
# pylint: disable=g-complex-comprehension

config.parse_flags_with_absl()


class ProfilerCuptiTest(parameterized.TestCase):

  def setUp(self):
    if not HAS_MOSAIC_GPU:
      self.skipTest("jaxlib built without Mosaic GPU")
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Only works on NVIDIA GPUs")
    super().setUp()
    self.x = jnp.arange(1024 * 1024)
    self.f = lambda x: 2 * x

  def test_measure_cupti_explicit(self):
    _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x)
    self.assertIsInstance(runtime_ms, float)

  def test_measure_per_kernel(self):
    _, runtimes_ms = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
    for item in runtimes_ms:
      self.assertIsInstance(item, tuple)
      self.assertEqual(len(item), 2)
      name, runtime_ms = item
      self.assertIsInstance(name, str)
      self.assertIsInstance(runtime_ms, float)

  def test_measure_cupti_repeated(self):
    f_profiled = profiler.measure(self.f, mode="cupti")
    n = 3
    timings = [f_profiled(self.x)[1] for _ in range(n)]
    for item in timings:
      self.assertIsInstance(item, float)

  def test_measure_repeated_interleaved(self):
    # Test that kernels run outside of measure() are not captured.
    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
    self.assertEqual(len(timings), 1)
    self.f(self.x)
    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
    self.assertEqual(len(timings), 1)

  def test_measure_double_subscription(self):
    # This needs to run in a separate process, otherwise it affects the
    # outcomes of other tests since CUPTI state is global.
    self.skipTest("Must run in a separate process from other profiler tests")
    # Initialize profiler manually, which subscribes to CUPTI. There can only
    # be one CUPTI subscriber at a time.
    jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
    with self.assertRaisesRegex(
        RuntimeError,
        "Attempted to subscribe to CUPTI while another subscriber, "
        "such as Nsight Systems or Nsight Compute, is active."):
      profiler.measure(self.f, aggregate=False)(self.x)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())