Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[Mosaic GPU] Add CUPTI profiler alongside events-based implementation
Parent: 12b45b3235
Commit: cc22334c21
@@ -576,7 +576,7 @@ def benchmark_and_verify(
       head_dim=head_dim,
       **kwargs,
   )
-  out, runtime = profiler.measure(f, q[0], k[0], v[0])
+  out, runtime = profiler.measure(f)(q[0], k[0], v[0])
   out = out[None]
 
   @jax.jit
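Note: throughout these call sites, the calling convention changes from
profiler.measure(f, *args) to profiler.measure(f)(*args): measure is now a
higher-order function that wraps f once and can then be invoked repeatedly.
A minimal illustrative sketch (not part of this diff; assumes a CUDA-enabled
jaxlib with Mosaic GPU support):

    # Illustrative only: the curried measure() API introduced by this commit.
    import jax.numpy as jnp
    from jax.experimental.mosaic.gpu import profiler

    x = jnp.arange(1024 * 1024)
    timed_add = profiler.measure(lambda a, b: a + b)  # wrap once
    out, runtime_ms = timed_add(x, x)  # outputs plus runtime in milliseconds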
@@ -360,7 +360,7 @@ def verify(
       wgmma_impl=WGMMADefaultImpl,
       profiler_spec=prof_spec,
   )
-  z, runtime = profiler.measure(f, x, y)
+  z, runtime = profiler.measure(f)(x, y)
 
   if rhs_transpose:
     dimension_numbers = ((1,), (1,)), ((), ())
@@ -382,7 +382,7 @@ def verify(
       preferred_element_type=out_dtype,
   ).astype(out_dtype)
 
-  ref, ref_runtime = profiler.measure(ref_f, x, y)
+  ref, ref_runtime = profiler.measure(ref_f)(x, y)
   np.testing.assert_allclose(
       z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
   )
@@ -426,7 +426,7 @@ if __name__ == "__main__":
       f = build_kernel(
           m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
       )
-      _, runtime = profiler.measure(f, x, y)
+      _, runtime = profiler.measure(f)(x, y)
     except ValueError as e:
       if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
         raise
@@ -69,20 +69,9 @@ def _event_elapsed(start_event, end_event):
   )(start_event, end_event)
 
 
-def measure(
+def _measure_events(
     f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[T, float]:
-  """Measures the time it takes to execute the function on the GPU.
-
-  Args:
-    f: The function to measure. It must accept at least one argument and return
-      at least one output to be measurable.
-    *args: The arguments to pass to ``f``.
-    **kwargs: The keyword arguments to pass to ``f``.
-
-  Returns:
-    The return value of ``f`` and the elapsed time in milliseconds.
-  """
   if not has_registrations:
     raise RuntimeError(
         "This function requires jaxlib >=0.4.36 with CUDA support."
@@ -109,6 +98,91 @@ def measure(
   return outs, float(elapsed)
 
 
+def _measure_cupti(f, aggregate):
+  def wrapper(*args, **kwargs):
+    mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
+    try:
+      results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
+    finally:
+      timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
+    if not timings:
+      return results, None
+    elif aggregate:
+      return results, sum(item[1] for item in timings)
+    else:
+      return results, timings
+  return wrapper
+
+
+def measure(f: Callable, *, mode: str = "cupti", aggregate: bool = True
+            ) -> Callable:
+  """Sets up a function ``f`` for profiling on GPU.
+
+  ``measure`` is a higher-order function that augments the argument ``f`` to
+  return GPU runtime in milliseconds, in addition to its proper outputs.
+
+  Args:
+    f: The function to measure. It must accept at least one argument and return
+      at least one output to be measurable.
+    mode: The mode of operation. Possible values are:
+
+      - "cupti", for CUPTI-based profiling.
+      - "events", for CUDA events-based profiling.
+
+      The two modes use different measurement methodologies and should not be
+      treated as interchangeable backends. See the Notes section for important
+      discussion.
+    aggregate: Whether to report an aggregate runtime. When ``False`` (only
+      supported by ``mode="cupti"``), the per-kernel timings are returned as a
+      list of tuples ``(<kernel name>, <runtime in ms>)``.
+
+  Returns:
+    A new function ``g`` that returns the measured GPU runtime as its last
+    additional output. Otherwise ``g`` accepts the same inputs and returns the
+    same outputs as ``f``.
+
+  Notes:
+    `CUPTI (CUDA Profiling Tools Interface)
+    <https://docs.nvidia.com/cupti/index.html>`_ is a high-accuracy,
+    high-precision profiling and tracing API, used in particular by Nsight
+    Systems and Nsight Compute. When using ``measure`` with ``mode="cupti"``,
+    device (GPU) execution runtimes are recorded for each kernel launched
+    during the execution of the function. In that mode, setting
+    ``aggregate=True`` will sum the individual kernel runtimes to arrive at an
+    aggregate measurement. The "gaps" between the kernels when the device is
+    idle are not included in the aggregate.
+
+    The CUPTI API only allows a single "subscriber". This means that the
+    CUPTI-based profiler will fail when the program is run using tools that
+    make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems, or
+    Nsight Compute.
+
+    ``mode="events"`` uses a different approach: a CUDA event is recorded
+    before and after the function ``f`` is executed. The reported runtime is
+    the time elapsed between the two events. In particular, included in the
+    measurement are:
+
+    - any potential "gaps" between the kernels when the device is idle
+    - any potential "gaps" between the "before" event and the start of the
+      first kernel, or between the end of the last kernel and the "after" event
+
+    In an attempt to minimize the second effect, internally the events-based
+    implementation may execute ``f`` more than once to "warm up" and exclude
+    compilation time from the measurement.
+  """
+  match mode:
+    case "cupti":
+      return _measure_cupti(f, aggregate)
+    case "events":
+      if not aggregate:
+        raise ValueError(f"{aggregate=} is not supported with {mode=}")
+      def measure_events_wrapper(*args, **kwargs):
+        return _measure_events(f, *args, **kwargs)
+      return measure_events_wrapper
+    case _:
+      raise ValueError(f"Unrecognized profiler mode {mode}")
+
+
 class ProfilerSpec:
   ENTER = 0
   EXIT = 1 << 31
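Putting the new measure() together, a hedged usage sketch of both modes and of
the per-kernel breakdown (illustrative only; assumes a CUDA GPU and that no
other CUPTI subscriber, such as Nsight, is attached):

    import jax.numpy as jnp
    from jax.experimental.mosaic.gpu import profiler

    f = lambda x: 2 * x
    x = jnp.arange(1024 * 1024)

    # Default CUPTI mode: aggregate device runtime, summed over kernels
    # (device-idle gaps excluded).
    _, total_ms = profiler.measure(f)(x)

    # CUPTI mode with aggregate=False: a list of (kernel_name, runtime_ms).
    _, per_kernel = profiler.measure(f, aggregate=False)(x)

    # Events mode: wall time between two CUDA events, gaps included.
    _, events_ms = profiler.measure(f, mode="events")(x)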
@@ -274,7 +274,7 @@ def main(unused_argv):
     for block_kv in (256, 128, 64):
       config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
       try:
-        out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
+        out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
         if seq_len < 32768:
           out_ref = attention_reference(q, k, v)
           np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
@@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include <memory>
+#include <cstddef>
 #include <cstdint>
+#include <new>
+#include <stdexcept>
 #include <string>
+#include <tuple>
 #include <vector>
 
 #include "nanobind/nanobind.h"
+#include "nanobind/stl/tuple.h"
 #include "nanobind/stl/vector.h"
+#include "absl/cleanup/cleanup.h"
+#include "absl/strings/str_cat.h"
 #include "jaxlib/gpu/vendor.h"
@@ -118,6 +124,75 @@ XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
   return kEventElapsed->Call(call_frame);
 }
 
+#define THROW(...)                                                 \
+  do {                                                             \
+    throw std::runtime_error(                                      \
+        absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__)); \
+  } while (0)
+
+#define THROW_IF(expr, ...)       \
+  do {                            \
+    if (expr) THROW(__VA_ARGS__); \
+  } while (0)
+
+#define THROW_IF_CUPTI_ERROR(expr, ...)          \
+  do {                                           \
+    CUptiResult _result = (expr);                \
+    if (_result != CUPTI_SUCCESS) {              \
+      const char* s;                             \
+      cuptiGetErrorMessage(_result, &s);         \
+      THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__); \
+    }                                            \
+  } while (0)
+
+// CUPTI can only have one subscriber per process, so it's ok to make the
+// profiler state global.
+struct {
+  CUpti_SubscriberHandle subscriber;
+  std::vector<std::tuple<const char* /*kernel_name*/, double /*ms*/>> timings;
+} profiler_state;
+
+void callback_request(uint8_t** buffer, size_t* size, size_t* maxNumRecords) {
+  // A 10 MiB buffer size is generous but somewhat arbitrary; it's at the upper
+  // bound of what's recommended in the CUPTI documentation:
+  // https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
+  const int buffer_size = 10 * (1 << 20);
+  // 8 byte alignment is specified in the official CUPTI code samples, see
+  // extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
+  // installation.
+  *buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
+  *size = buffer_size;
+  *maxNumRecords = 0;
+}
+
+void callback_complete(CUcontext context, uint32_t streamId,
+                       uint8_t* buffer, size_t size, size_t validSize) {
+  // Take ownership of the buffer once CUPTI is done using it.
+  absl::Cleanup cleanup = [buffer]() {
+    operator delete[](buffer, std::align_val_t(8));
+  };
+  CUpti_Activity* record = nullptr;
+  while (true) {
+    CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record);
+    if (status == CUPTI_SUCCESS) {
+      if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
+        // TODO(andportnoy) handle multi-GPU
+        CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;
+        // Convert integer nanoseconds to floating point milliseconds to match
+        // the interface of the events-based profiler.
+        double duration_ms = (kernel->end - kernel->start) / 1e6;
+        profiler_state.timings.push_back(
+            std::make_tuple(kernel->name, duration_ms));
+      }
+    } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
+      // No more records available.
+      break;
+    } else {
+      THROW_IF_CUPTI_ERROR(status);
+    }
+  }
+}
+
 NB_MODULE(_mosaic_gpu_ext, m) {
   m.def("registrations", []() {
     return nb::make_tuple(
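The (kernel name, ms) tuples collected here are what _cupti_get_timings later
hands back to Python. When the profiled function launches the same kernel many
times, a caller may want per-kernel totals; a small post-processing sketch
(hypothetical helper, not part of this commit):

    from collections import defaultdict

    def total_per_kernel(timings):
      # timings: list of (kernel_name, runtime_ms) tuples, as returned by
      # profiler.measure(f, aggregate=False).
      totals = defaultdict(float)
      for name, ms in timings:
        totals[name] += ms
      return dict(totals)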
@@ -139,6 +214,35 @@ NB_MODULE(_mosaic_gpu_ext, m) {
       }
     }
   });
+  m.def("_cupti_init", []() {
+    profiler_state.timings.clear();
+    // Ok to pass nullptr for the callback here because we don't register any
+    // callbacks through cuptiEnableCallback.
+    auto subscribe_result = cuptiSubscribe(
+        &profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
+    if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
+      THROW(
+          "Attempted to subscribe to CUPTI while another subscriber, such as "
+          "Nsight Systems or Nsight Compute, is active. The CUPTI backend of "
+          "the Mosaic GPU profiler cannot be used in that mode, since CUPTI "
+          "does not support multiple subscribers.");
+    }
+    THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
+    THROW_IF_CUPTI_ERROR(
+        cuptiActivityRegisterCallbacks(callback_request, callback_complete),
+        "failed to register CUPTI activity callbacks");
+    THROW_IF_CUPTI_ERROR(
+        cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
+        "failed to enable tracking of kernel activity by CUPTI");
+  });
+  m.def("_cupti_get_timings", []() {
+    THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
+                         "failed to unsubscribe from CUPTI");
+    THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
+                         "failed to flush CUPTI activity buffers");
+    THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
+    return profiler_state.timings;
+  });
 }
 
 }  // namespace
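Since _cupti_init raises when CUPTI is already claimed (for example under
Nsight Systems or Compute Sanitizer), callers that must keep working under
such tools can fall back to the events backend. A hedged sketch (hypothetical
wrapper, not part of this commit):

    from jax.experimental.mosaic.gpu import profiler

    def measure_with_fallback(f):
      # Prefer CUPTI; fall back to CUDA events if another CUPTI subscriber
      # already owns the interface (surfaced as a RuntimeError).
      def wrapper(*args, **kwargs):
        try:
          return profiler.measure(f, mode="cupti")(*args, **kwargs)
        except RuntimeError:
          return profiler.measure(f, mode="events")(*args, **kwargs)
      return wrapper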
@@ -88,3 +88,17 @@ jax_multiplatform_test(
         "//jax/experimental/mosaic/gpu/examples:flash_attention",
     ] + py_deps("absl/testing"),
 )
+
+jax_multiplatform_test(
+    name = "profiler_cupti_test",
+    srcs = ["profiler_cupti_test.py"],
+    enable_backends = [],
+    enable_configs = ["gpu_h100"],
+    deps = [
+        "//jax:mosaic_gpu",
+    ] + py_deps("absl/testing"),
+    tags = [
+        "noasan",  # CUPTI leaks memory
+        "nomsan",
+    ],
+)
@@ -1691,9 +1691,10 @@ class FragmentedArrayTest(TestCase):
 
 class ProfilerTest(TestCase):
 
-  def test_measure(self):
+  def test_measure_events_explicit(self):
     x = jnp.arange(1024 * 1024)
-    profiler.measure(lambda x, y: x + y, x, x)  # This is just a smoke test
+    _, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x)
+    self.assertIsInstance(runtime_ms, float)
 
   def test_profile(self):
     def kernel(ctx, src, dst, _):
tests/mosaic/profiler_cupti_test.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+# Copyright 2024 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Mosaic GPU CUPTI-based profiler."""
+
+from absl.testing import absltest, parameterized
+import jax
+from jax._src import config
+from jax._src import test_util as jtu
+import jax.numpy as jnp
+
+try:
+  import jax._src.lib.mosaic_gpu  # noqa: F401
+  HAS_MOSAIC_GPU = True
+except ImportError:
+  HAS_MOSAIC_GPU = False
+else:
+  from jax.experimental.mosaic.gpu import profiler
+
+
+# ruff: noqa: F405
+# pylint: disable=g-complex-comprehension
+config.parse_flags_with_absl()
+
+
+class ProfilerCuptiTest(parameterized.TestCase):
+
+  def setUp(self):
+    if not HAS_MOSAIC_GPU:
+      self.skipTest("jaxlib built without Mosaic GPU")
+    if not jtu.test_device_matches(["cuda"]):
+      self.skipTest("Only works on NVIDIA GPUs")
+    super().setUp()
+    self.x = jnp.arange(1024 * 1024)
+    self.f = lambda x: 2 * x
+
+  def test_measure_cupti_explicit(self):
+    _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_per_kernel(self):
+    _, runtimes_ms = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+    for item in runtimes_ms:
+      self.assertIsInstance(item, tuple)
+      self.assertEqual(len(item), 2)
+      name, runtime_ms = item
+      self.assertIsInstance(name, str)
+      self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_cupti_repeated(self):
+    f_profiled = profiler.measure(self.f, mode="cupti")
+    n = 3
+    timings = [f_profiled(self.x)[1] for _ in range(n)]
+    for item in timings:
+      self.assertIsInstance(item, float)
+
+  def test_measure_repeated_interleaved(self):
+    # Test that kernels run outside of measure() are not captured.
+    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+    self.f(self.x)
+    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+
+  def test_measure_double_subscription(self):
+    # This needs to run in a separate process, otherwise it affects the
+    # outcomes of other tests, since CUPTI state is global.
+    self.skipTest("Must run in a separate process from other profiler tests")
+    # Initialize the profiler manually, which subscribes to CUPTI. There can
+    # only be one CUPTI subscriber at a time.
+    jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "Attempted to subscribe to CUPTI while another subscriber, "
+        "such as Nsight Systems or Nsight Compute, is active."):
+      profiler.measure(self.f, aggregate=False)(self.x)
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())