mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

This PR is a follow-up to #18881. The changes were generated by adding `from __future__ import annotations` to the files that did not already have it and then running `pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py`.
119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utilities for instrumenting code.

Code points can be marked as a named event. Every time an event is reached
during program execution, the registered listeners will be invoked.

A typical listener callback is to send an event to a metrics collector for
aggregation/exporting.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Protocol
|
|
|
|
|
|
class EventListenerWithMetadata(Protocol):
  """Structural type of a callback that receives a named event.

  Any callable that accepts the event name followed by optional keyword
  metadata (string or integer values) and returns nothing conforms to this
  protocol.
  """

  def __call__(self, event: str, **kwargs: str | int) -> None:
    ...
|
|
|
|
|
|
class EventDurationListenerWithMetadata(Protocol):
  """Structural type of a callback that receives an event and its duration.

  Any callable that accepts the event name, a duration in seconds, and
  optional keyword metadata (string or integer values) and returns nothing
  conforms to this protocol.
  """

  def __call__(self, event: str, duration_secs: float,
               **kwargs: str | int) -> None:
    ...
|
|
|
|
|
|
# Callbacks invoked by record_event().
_event_listeners: list[EventListenerWithMetadata] = []
# Callbacks invoked by record_event_duration_secs().
_event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
|
|
|
|
|
|
def record_event(event: str, **kwargs: str | int) -> None:
  """Record an event.

  Every callback currently held in the module-level event listener list is
  called, in list order, with the event name and the given keyword metadata.

  If **kwargs are specified, all of the named arguments have to be passed in the
  same order across all invocations of this method for the same event.

  Args:
    event: Name of the event being recorded.
    **kwargs: Optional string or integer metadata forwarded to each listener.
  """
  for callback in _event_listeners:
    callback(event, **kwargs)
|
|
|
|
|
|
def record_event_duration_secs(event: str, duration: float,
                               **kwargs: str | int) -> None:
  """Record an event duration in seconds (float).

  Every callback currently held in the module-level duration listener list is
  called, in list order, with the event name, the duration (passed
  positionally), and the given keyword metadata.

  If **kwargs are specified, all of the named arguments have to be passed in the
  same order across all invocations of this method for the same event.

  Args:
    event: Name of the event being recorded.
    duration: Duration of the event, in seconds.
    **kwargs: Optional string or integer metadata forwarded to each listener.
  """
  for callback in _event_duration_secs_listeners:
    callback(event, duration, **kwargs)
|
|
|
|
|
|
def register_event_listener(
    callback: EventListenerWithMetadata,
) -> None:
  """Register a callback to be invoked during record_event().

  The callback is appended to the module-level listener list; no
  de-duplication is performed, so registering the same callback twice stores
  it twice.

  Args:
    callback: Callable taking the event name and optional keyword metadata.
  """
  _event_listeners.append(callback)
|
|
|
|
|
|
def register_event_duration_secs_listener(
    callback: EventDurationListenerWithMetadata) -> None:
  """Register a callback to be invoked during record_event_duration_secs().

  The callback is appended to the module-level duration listener list; no
  de-duplication is performed.

  Args:
    callback: Callable taking the event name, a duration in seconds, and
      optional keyword metadata.
  """
  _event_duration_secs_listeners.append(callback)
|
|
|
|
def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
  """Get event duration listeners.

  Returns:
    A new list holding the currently registered duration listeners; mutating
    it does not change the module-level registry.
  """
  return list(_event_duration_secs_listeners)
|
|
|
|
def get_event_listeners() -> list[EventListenerWithMetadata]:
  """Get event listeners.

  Returns:
    A new list holding the currently registered event listeners; mutating it
    does not change the module-level registry.
  """
  return list(_event_listeners)
|
|
|
|
def clear_event_listeners() -> None:
  """Clear event listeners.

  Both module-level registries (event listeners and event duration listeners)
  are rebound to fresh empty lists; the previous list objects are left
  unmodified.
  """
  global _event_listeners, _event_duration_secs_listeners
  _event_listeners = []
  _event_duration_secs_listeners = []
|
|
|
|
def _unregister_event_duration_listener_by_callback(
    callback: EventDurationListenerWithMetadata) -> None:
  """Unregister an event duration listener by callback.

  This function is supposed to be called for testing only.

  Only the first matching entry is removed if the callback was registered
  more than once.

  Args:
    callback: A previously registered duration listener.

  Raises:
    AssertionError: If the callback is not currently registered (when
      assertions are enabled).
  """
  assert callback in _event_duration_secs_listeners
  _event_duration_secs_listeners.remove(callback)
|
|
|
|
def _unregister_event_duration_listener_by_index(index: int) -> None:
  """Unregister an event duration listener by index.

  This function is supposed to be called for testing only.

  Args:
    index: Position of the listener to remove. Negative values count from the
      end of the list, as with ordinary list indexing.

  Raises:
    AssertionError: If the index is outside ``[-size, size)`` for the current
      number of registered duration listeners (when assertions are enabled).
  """
  size = len(_event_duration_secs_listeners)
  assert -size <= index < size
  del _event_duration_secs_listeners[index]
|
|
|
|
def _unregister_event_listener_by_callback(
    callback: EventListenerWithMetadata) -> None:
  """Unregister an event listener by callback.

  This function is supposed to be called for testing only.

  Only the first matching entry is removed if the callback was registered
  more than once.

  Args:
    callback: A previously registered event listener.

  Raises:
    AssertionError: If the callback is not currently registered (when
      assertions are enabled).
  """
  assert callback in _event_listeners
  _event_listeners.remove(callback)
|