jax.profiler: remove deprecated functions

This commit is contained in:
Jake VanderPlas 2022-07-11 16:17:35 -07:00
parent 12ce369c94
commit 2543542fa8
3 changed files with 0 additions and 39 deletions

View File

@ -34,13 +34,3 @@ profiling features.
device_memory_profile
save_device_memory_profile
Deprecated functions
--------------------
.. autosummary::
:toctree: _autosummary
trace_function
TraceContext
StepTraceContext

View File

@ -234,15 +234,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe):
pass
# TODO: remove this sometime after jax 0.2.11 is released
class TraceContext(TraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"TraceContext has been renamed to TraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
class StepTraceAnnotation(TraceAnnotation):
"""Context manager that generates a step trace event in the profiler.
@ -269,15 +260,6 @@ class StepTraceAnnotation(TraceAnnotation):
super().__init__(name, _r=1, **kwargs)
# TODO: remove this sometime after jax 0.2.11 is released
class StepTraceContext(StepTraceAnnotation):
def __init__(self, *args, **kwargs):
warnings.warn(
"StepTraceContext has been renamed to StepTraceAnnotation. This alias "
"will eventually be removed; please update your code.")
super().__init__(*args, **kwargs)
def annotate_function(func: Callable, name: Optional[str] = None,
**decorator_kwargs):
"""Decorator that generates a trace event for the execution of a function.
@ -314,14 +296,6 @@ def annotate_function(func: Callable, name: Optional[str] = None,
return wrapper
# TODO: remove this sometime after jax 0.2.11 is released
def trace_function(*args, **kwargs):
warnings.warn(
"trace_function has been renamed to annotate_function. This alias "
"will eventually be removed; please update your code.")
return annotate_function(*args, **kwargs)
def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.

View File

@ -14,9 +14,7 @@
from jax._src.profiler import (
StepTraceAnnotation as StepTraceAnnotation,
StepTraceContext as StepTraceContext,
TraceAnnotation as TraceAnnotation,
TraceContext as TraceContext,
device_memory_profile as device_memory_profile,
save_device_memory_profile as save_device_memory_profile,
start_server as start_server,
@ -25,5 +23,4 @@ from jax._src.profiler import (
stop_trace as stop_trace,
trace as trace,
annotate_function as annotate_function,
trace_function as trace_function,
)