rocm_jax/jax/_src/source_info_util.py

125 lines
4.1 KiB
Python
Raw Normal View History

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import itertools
import os.path
import threading
from typing import Optional, Iterator, NamedTuple
import jax.version
from jax._src.lib import xla_client
from jax._src import traceback_util
# Register this file with traceback_util's exclusion list.
# NOTE(review): presumably so frames from this module are hidden from
# filtered user-facing tracebacks; confirm against traceback_util.
traceback_util.register_exclusion(__file__)
# Short local aliases for the traceback types exported by xla_client.
Traceback = xla_client.Traceback
Frame = xla_client.Frame
# Path prefixes that is_user_filename() treats as non-user code. Seeded with
# the directory that contains the jax.version module.
_exclude_paths = [os.path.dirname(jax.version.__file__)]
def register_exclusion(path):
  """Adds `path` to the prefixes that are not considered user code.

  Args:
    path: a filename prefix; any filename starting with it is rejected by
      `is_user_filename` (unless the filename ends with "_test.py").
  """
  _exclude_paths.append(path)
class SourceInfo(NamedTuple):
  """Source provenance record holding an optional traceback."""
  traceback: Optional[Traceback]

  def replace(self, *, traceback: Optional[Traceback] = None) -> 'SourceInfo':
    """Returns a copy of this record with `traceback` substituted.

    A falsy `traceback` (including the default None) falls back to the
    traceback already stored, so this method cannot clear the field.

    Args:
      traceback: the replacement traceback, or a falsy value to keep the
        current one.

    Returns:
      A `SourceInfo` built via `_replace`.
    """
    chosen = traceback if traceback else self.traceback
    return self._replace(traceback=chosen)
def new_source_info() -> SourceInfo:
  """Returns a fresh `SourceInfo` with no traceback attached."""
  return SourceInfo(traceback=None)
[JAX] Add support for generating "equation profiles" in JAX. An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result. For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives: ``` $ pprof --tags /tmp/myprof Main binary filename not available. primitive: Total 6062.0 1509.0 (24.89%): mul 936.0 (15.44%): add 589.0 ( 9.72%): reshape 492.0 ( 8.12%): div 485.0 ( 8.00%): broadcast_in_dim 330.0 ( 5.44%): reduce_sum 322.0 ( 5.31%): integer_pow 230.0 ( 3.79%): add_any 174.0 ( 2.87%): convert_element_type 160.0 ( 2.64%): select 158.0 ( 2.61%): conv_general_dilated 116.0 ( 1.91%): sub 110.0 ( 1.81%): eq 110.0 ( 1.81%): neg 104.0 ( 1.72%): max 53.0 ( 0.87%): rsqrt 52.0 ( 0.86%): rev 49.0 ( 0.81%): custom_jvp_call_jaxpr 49.0 ( 0.81%): gt 5.0 (0.082%): xla_call 4.0 (0.066%): min 3.0 (0.049%): dot_general 3.0 (0.049%): lt 2.0 (0.033%): cos 2.0 (0.033%): exp 2.0 (0.033%): iota 2.0 (0.033%): log 2.0 (0.033%): psum 2.0 (0.033%): reduce_max 2.0 (0.033%): stop_gradient 1.0 (0.016%): argmax 1.0 (0.016%): reduce_window_max 1.0 (0.016%): select_and_scatter_add 1.0 (0.016%): transpose 1.0 (0.016%): xla_pmap ``` Or the lines of code that generated the most equations: ``` $ pprof --text /tmp/myprof Main binary filename not available. 
Type: equations Showing nodes accounting for 6038, 99.60% of 6062 total Dropped 5 nodes (cum <= 30) flat flat% sum% cum cum% 1537 25.35% 25.35% 1537 25.35% _compute_stats 1484 24.48% 49.84% 1484 24.48% _normalize 849 14.01% 63.84% 6062 100% __call__ 644 10.62% 74.46% 644 10.62% <unknown> 483 7.97% 82.43% 483 7.97% <unknown> 392 6.47% 88.90% 6061 100% train_step 324 5.34% 94.24% 324 5.34% <unknown> 161 2.66% 96.90% 161 2.66% <unknown> 57 0.94% 97.84% 4292 70.80% loss_fn 52 0.86% 98.70% 52 0.86% schedule 39 0.64% 99.34% 39 0.64% softmax_cross_entropy 8 0.13% 99.47% 30 0.49% compute_metrics 6 0.099% 99.57% 61 1.01% cross_entropy_loss 1 0.016% 99.59% 1321 21.79% apply_gradients 1 0.016% 99.60% 6062 100% train_and_evaluate 0 0% 99.60% 6062 100% <unknown> 0 0% 99.60% 6062 100% __init__ 0 0% 99.60% 3872 63.87% _call_wrapped_method 0 0% 99.60% 6062 100% _run_and_get_tests_result 0 0% 99.60% 6062 100% _run_code_in_main 0 0% 99.60% 6062 100% _run_in_app 0 0% 99.60% 6062 100% _run_main 0 0% 99.60% 3872 63.87% apply 0 0% 99.60% 161 2.66% apply_updates 0 0% 99.60% 6062 100% main 0 0% 99.60% 6062 100% main_function 0 0% 99.60% 6062 100% run 0 0% 99.60% 6062 100% runTests 0 0% 99.60% 6062 100% run_filename_as_main 0 0% 99.60% 6062 100% run_tests 0 0% 99.60% 3872 63.87% scope_fn 0 0% 99.60% 6062 100% test_train_and_evaluate 0 0% 99.60% 1159 19.12% update_fn 0 0% 99.60% 3872 63.87% wrapped_fn 0 0% 99.60% 3872 63.87% wrapped_module_method 0 0% 99.60% 3872 63.87% wrapper ``` I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization. ``` pprof --tagleaf=primitive --http=:8080 myprof ``` [XLA:Python] Add helpers to Traceback and for working with pprof profiles. * Define hash and equality operators on Tracebacks. * Add functions for converting JSON to and from pprof profile protocol buffers. * Add a helper method that exposes PyCode_Addr2Line to Python. PiperOrigin-RevId: 421395346
2022-01-12 14:27:17 -08:00
def is_user_filename(filename: str) -> bool:
  """Heuristic that guesses whether `filename` belongs to the user's code.

  Args:
    filename: the file name of a stack frame.

  Returns:
    True if `filename` ends with "_test.py", or if it starts with none of the
    prefixes in `_exclude_paths`; False otherwise.
  """
  # Test files always count as user code, even when they live under an
  # excluded prefix.
  if filename.endswith("_test.py"):
    return True
  for excluded_prefix in _exclude_paths:
    if filename.startswith(excluded_prefix):
      return False
  return True
def user_frames(source_info: SourceInfo) -> Iterator[Frame]:
  """Iterator over the user's frames.

  Args:
    source_info: the record whose traceback is inspected.

  Returns:
    A lazy iterator over the frames of `source_info.traceback` whose
    `file_name` passes `is_user_filename`; empty if the traceback is falsy.
  """
  # Guess the user's frame is the innermost frame not in the jax source tree.
  # We don't use traceback_util.path_starts_with because that incurs
  # filesystem access, which may be slow; we call this function when e.g.
  # adding source provenance annotations to XLA lowerings, so we don't want to
  # incur the cost.
  # We consider files that end with _test.py as user frames, to allow testing
  # this mechanism from tests.
  tb = source_info.traceback
  # The frame list is read eagerly here; only the filtering is deferred.
  all_frames = tb.frames if tb else []
  return (frame for frame in all_frames if is_user_filename(frame.file_name))
def user_frame(source_info: SourceInfo) -> Optional[Frame]:
  """Returns the first frame produced by `user_frames`, or None if empty."""
  for frame in user_frames(source_info):
    return frame
  return None
def summarize(source_info: SourceInfo, num_frames=1) -> str:
  """Formats the leading user frames of `source_info` as text.

  Args:
    source_info: the record whose user frames are summarized.
    num_frames: maximum number of frames taken from `user_frames`.

  Returns:
    One "file:line (function)" entry per selected frame, joined by newlines
    in the reverse of the order `user_frames` yields them. The result is the
    empty string when there are no user frames.
  """
  selected = itertools.islice(user_frames(source_info), num_frames)
  lines = []
  for frame in selected:
    if frame:
      lines.append(
          f"{frame.file_name}:{frame.line_num} ({frame.function_name})")
    else:
      # NOTE(review): user_frames only yields real frames, so this branch
      # looks unreachable unless a Frame can be falsy; kept for parity.
      lines.append("unknown")
  lines.reverse()
  return '\n'.join(lines)
class _SourceInfoContext(threading.local):
  """Thread-local holder for the `SourceInfo` currently in effect."""
  # The active source info for the current thread.
  context: SourceInfo

  def __init__(self):
    # Start with an empty record (no traceback).
    self.context = new_source_info()

# Module-wide instance read by current() and updated by user_context().
_source_info_context = _SourceInfoContext()
def current() -> SourceInfo:
  """Returns the source info in effect for the calling thread.

  If the thread-local context already carries a truthy traceback it is
  returned unchanged; otherwise a copy is returned whose traceback is
  captured now via `xla_client.Traceback.get_traceback()`.
  """
  info = _source_info_context.context
  if info.traceback:
    return info
  return info.replace(traceback=xla_client.Traceback.get_traceback())
class JaxStackTraceBeforeTransformation(Exception):
  """Marker exception carrying the pre-transformation stack trace."""

# Text appended after the formatted original exception in the message of a
# JaxStackTraceBeforeTransformation.
_message = (
    'The preceding stack trace is the source of the JAX operation that, once '
    'transformed by JAX, triggered the following exception.\n'
    '\n--------------------')
def has_user_context(e):
  """Reports whether `e`'s cause chain already has a user stack trace.

  Args:
    e: an exception, or None.

  Returns:
    True if `e` or any exception reachable by following `__cause__` is a
    `JaxStackTraceBeforeTransformation`; False otherwise (including for None).
  """
  cursor = e
  while cursor is not None:
    if isinstance(cursor, JaxStackTraceBeforeTransformation):
      return True
    cursor = cursor.__cause__
  return False
@contextlib.contextmanager
def user_context(c: Optional[Traceback]):
  """Context manager that installs `c` as the current source traceback.

  While the body runs, the thread-local source info is replaced by a copy
  carrying `c` (a falsy `c` leaves the existing traceback in place). The
  previous source info is restored on exit.

  If the body raises an `Exception`, `c` is not None, and the exception's
  cause chain does not already contain a `JaxStackTraceBeforeTransformation`,
  a new `JaxStackTraceBeforeTransformation` built from the filtered form of
  `c` is spliced in as the exception's `__cause__`. The original exception is
  always re-raised.

  Args:
    c: the traceback to attribute operations to, or None.
  """
  saved = _source_info_context.context
  _source_info_context.context = saved.replace(traceback=c)
  filtered_tb = None
  try:
    yield
  except Exception as e:
    if c is None or has_user_context(e):
      raise
    filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
    if filtered_tb:
      summary = traceback_util.format_exception_only(e)
      origin = JaxStackTraceBeforeTransformation(
          f'{summary}\n\n{_message}').with_traceback(filtered_tb)
      # Move e's existing chain onto the new exception, then make the new
      # exception the direct cause of e.
      origin.__context__ = e.__context__
      origin.__cause__ = e.__cause__
      origin.__suppress_context__ = e.__suppress_context__
      e.__context__ = None
      e.__cause__ = origin
    raise
  finally:
    _source_info_context.context = saved
    # NOTE(review): presumably dropped explicitly so the traceback object is
    # not kept alive by this frame; confirm intent.
    del filtered_tb