2020-11-04 11:54:01 -08:00
|
|
|
# Copyright 2020 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import contextlib
|
2021-10-28 11:06:58 -07:00
|
|
|
import dataclasses
|
2022-02-07 14:40:11 -08:00
|
|
|
import functools
|
2021-01-05 14:52:54 -08:00
|
|
|
import itertools
|
2020-11-04 11:54:01 -08:00
|
|
|
import os.path
|
|
|
|
import threading
|
2022-02-08 16:17:09 -08:00
|
|
|
import types
|
2021-10-28 11:06:58 -07:00
|
|
|
from typing import Optional, Iterator, NamedTuple, Union, Tuple
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2020-11-18 10:08:18 -05:00
|
|
|
import jax.version
|
2022-03-04 10:25:22 -05:00
|
|
|
from jax._src.lib import xla_client
|
2020-11-04 11:54:01 -08:00
|
|
|
|
|
|
|
from jax._src import traceback_util
|
|
|
|
# Register this module's own file path as an exclusion with traceback_util.
# NOTE(review): presumably this hides frames from this file in filtered
# tracebacks -- confirm against traceback_util.register_exclusion.
traceback_util.register_exclusion(__file__)
|
|
|
|
|
|
|
|
|
2021-07-26 13:44:57 +01:00
|
|
|
# Module-level alias for the traceback type exposed by xla_client; used in the
# type annotations below.
Traceback = xla_client.Traceback
|
2022-02-08 16:17:09 -08:00
|
|
|
|
|
|
|
class Frame(NamedTuple):
  """Source location of one Python stack frame.

  Instances are built by `_raw_frame_to_frame` from a code object.
  """
  # Filled from the code object's `co_filename`.
  file_name: str
  # Filled from the code object's `co_name`.
  function_name: str
  # Line number resolved from the frame's instruction offset.
  line_num: int
|
|
|
|
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2021-01-05 14:52:54 -08:00
|
|
|
# Path prefixes that is_user_filename() treats as non-user code. Seeded with
# the directory containing the jax.version module; register_exclusion()
# appends further entries.
_exclude_paths = [os.path.dirname(jax.version.__file__)]
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2021-01-05 14:52:54 -08:00
|
|
|
def register_exclusion(path: str) -> None:
  """Adds `path` to the list of prefixes that mark a file as non-user code.

  Args:
    path: a path prefix; any filename starting with it is rejected by
      `is_user_filename` (unless the filename ends with `_test.py`).
  """
  _exclude_paths.append(path)
|
|
|
|
|
2021-10-28 11:06:58 -07:00
|
|
|
class Scope(NamedTuple):
  """A plain named level in a name stack."""
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Returns `stack` with this scope's name prepended as a new component."""
    head = (self.name,)
    return (*head, *stack)
|
|
|
|
|
|
|
|
class Transform(NamedTuple):
  """A transformation marker in a name stack.

  Unlike `Scope`, it adds no component of its own: it rewrites the first
  component of the stack it wraps as `name(component)`.
  """
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Returns `stack` with its first component wrapped as `name(first)`.

    An empty `stack` has nothing to wrap, so the result is the empty tuple.
    """
    if not stack:
      return ()
    first, *rest = stack
    return (f'{self.name}({first})', *rest)
|
2021-10-28 11:06:58 -07:00
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class NameStack:
  """Immutable sequence of `Scope` and `Transform` entries.

  The first entry is the outermost one: `str()` renders the entries as a
  '/'-separated path in which each `Scope` contributes a component and each
  `Transform` wraps the component that follows it, e.g. a stack of
  `Scope('a'), Transform('jvp'), Scope('b')` renders as `a/jvp(b)`.
  """
  stack: Tuple[Union[Scope, Transform], ...] = ()

  def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
    """Returns a new stack with one `Scope` appended per given name.

    Args:
      name: a single name, or a tuple of names appended in order.
    """
    names = name if isinstance(name, tuple) else (name,)
    return NameStack(self.stack + tuple(Scope(n) for n in names))

  def wrap_name(self, name: str) -> str:
    """Returns `name` prefixed with this stack's rendered path.

    With an empty stack `name` is returned unchanged (no leading '/').
    """
    if not self.stack:
      return name
    return f'{str(self)}/{name}'

  def transform(self, transform_name: str) -> 'NameStack':
    """Returns a new stack with a `Transform(transform_name)` appended."""
    return NameStack((*self.stack, Transform(transform_name)))

  def __getitem__(self, idx) -> 'NameStack':
    """Returns the selected entries as a new `NameStack`.

    Args:
      idx: a slice, or a single index.
    """
    selected = self.stack[idx]
    if not isinstance(idx, slice):
      # A single index yields one Scope/Transform rather than a tuple of them.
      # Both are NamedTuples, so passing one straight through would silently
      # produce a NameStack whose "entries" are the entry's own fields.
      selected = (selected,)
    return NameStack(selected)

  def __len__(self):
    """Returns the number of entries in the stack."""
    return len(self.stack)

  def __add__(self, other: 'NameStack') -> 'NameStack':
    """Returns the concatenation `self` followed by `other`."""
    return NameStack(self.stack + other.stack)

  def __radd__(self, other: 'NameStack') -> 'NameStack':
    """Returns the concatenation `other` followed by `self`."""
    return NameStack(other.stack + self.stack)

  def __str__(self) -> str:
    """Renders the stack as a '/'-separated path (empty string if empty)."""
    components: Tuple[str, ...] = ()
    # Fold from the innermost entry outwards so that a Transform sees (and
    # wraps) the component produced by the entries after it.
    for entry in reversed(self.stack):
      components = entry.wrap(components)
    return '/'.join(components)
|
|
|
|
|
2021-10-29 15:49:31 -07:00
|
|
|
class SourceInfo(NamedTuple):
  """Pairs an optional traceback with a name stack."""
  traceback: Optional[Traceback]
  name_stack: NameStack

  def replace(self, *, traceback: Optional[Traceback] = None,
              name_stack: Optional[NameStack] = None) -> 'SourceInfo':
    """Returns a copy of this `SourceInfo` with the given fields overridden.

    Args:
      traceback: replacement traceback; a falsy value (including the default
        None) keeps the current one.
      name_stack: replacement name stack; None keeps the current one.
    """
    new_traceback = traceback if traceback else self.traceback
    new_name_stack = name_stack
    if new_name_stack is None:
      new_name_stack = self.name_stack
    return self._replace(traceback=new_traceback, name_stack=new_name_stack)
|
2021-10-29 15:49:31 -07:00
|
|
|
|
|
|
|
def new_source_info() -> SourceInfo:
  """Returns a `SourceInfo` with no traceback and an empty name stack."""
  return SourceInfo(traceback=None, name_stack=NameStack())
|
2021-10-29 15:49:31 -07:00
|
|
|
|
2022-01-12 14:27:17 -08:00
|
|
|
def is_user_filename(filename: str) -> bool:
  """Heuristic that guesses the identity of the user's code in a stack trace.

  A file counts as user code if its name ends with `_test.py`, or if it does
  not start with any of the registered exclusion prefixes.
  """
  if filename.endswith("_test.py"):
    return True
  # str.startswith accepts a tuple of prefixes and matches if any of them does.
  return not filename.startswith(tuple(_exclude_paths))
|
|
|
|
|
2022-02-08 16:17:09 -08:00
|
|
|
def _raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
  """Builds a `Frame` from a code object and an instruction offset.

  Args:
    code: the frame's code object; supplies the file and function names.
    lasti: value handed to `Traceback.code_addr2line` together with `code` to
      resolve the line number.
  """
  return Frame(code.co_filename,
               code.co_name,
               xla_client.Traceback.code_addr2line(code, lasti))
|
|
|
|
|
2022-01-12 14:27:17 -08:00
|
|
|
def user_frames(source_info: SourceInfo) -> Iterator[Frame]:
  """Iterator over the user's frames, filtering jax-internal frames.

  Args:
    source_info: the source info whose traceback is inspected. If it has no
      traceback the returned iterator is empty.

  Returns:
    A lazy iterator of `Frame`s, in the order reported by the traceback's
    `raw_frames()`, restricted to files accepted by `is_user_filename`.
  """
  # Guess the user's frame is the innermost frame not in the jax source tree
  # We don't use traceback_util.path_starts_with because that incurs filesystem
  # access, which may be slow; we call this function when e.g. adding source
  # provenance annotations to XLA lowerings, so we don't want to incur the cost.
  # We consider files that end with _test.py as user frames, to allow testing
  # this mechanism from tests.
  traceback = source_info.traceback
  # raw_frames() returns two parallel sequences: code objects and their
  # instruction offsets.
  code, lasti = traceback.raw_frames() if traceback else ([], [])
  return (_raw_frame_to_frame(frame_code, frame_lasti)  # type: ignore
          for frame_code, frame_lasti in zip(code, lasti)
          if is_user_filename(frame_code.co_filename))
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2022-02-07 14:40:11 -08:00
|
|
|
@functools.lru_cache(maxsize=64)
def user_frame(source_info: SourceInfo) -> Optional[Frame]:
  """Returns the first user frame of `source_info`, or None if there is none.

  Results are memoized per `source_info` in a 64-entry LRU cache.
  """
  for frame in user_frames(source_info):
    return frame
  return None
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2021-10-29 15:49:31 -07:00
|
|
|
def summarize(source_info: SourceInfo, num_frames=1) -> str:
  """Formats up to `num_frames` user frames as `file:line (function)` lines.

  Args:
    source_info: the source info whose user frames are summarized.
    num_frames: maximum number of user frames to include.

  Returns:
    One line per frame, joined with newlines, in the reverse of the order
    produced by `user_frames`. The empty string if there are no user frames.
  """
  frames = itertools.islice(user_frames(source_info), num_frames)
  # user_frames only yields Frame instances, which are always truthy, so no
  # per-frame fallback is needed here.
  frame_strs = [f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
                for frame in frames]
  return '\n'.join(reversed(frame_strs))
|
2020-11-04 11:54:01 -08:00
|
|
|
|
|
|
|
class _SourceInfoContext(threading.local):
  """Thread-local holder for the currently active `SourceInfo`."""
  context: SourceInfo

  def __init__(self):
    # Start from an empty SourceInfo (no traceback, empty name stack).
    self.context = new_source_info()
|
2020-11-04 11:54:01 -08:00
|
|
|
|
|
|
|
# Module-wide thread-local instance; its `context` attribute holds the active
# SourceInfo and is swapped by the context managers defined below.
_source_info_context = _SourceInfoContext()
|
|
|
|
|
2021-10-29 15:49:31 -07:00
|
|
|
def current() -> SourceInfo:
  """Returns the active `SourceInfo`, capturing a traceback if it has none.

  When the active context already carries a traceback it is returned as is;
  otherwise a copy is returned whose traceback is taken at the call site via
  `Traceback.get_traceback()`. The stored context itself is not modified.
  """
  active = _source_info_context.context
  if active.traceback:
    return active
  return active.replace(traceback=xla_client.Traceback.get_traceback())
|
2020-11-04 11:54:01 -08:00
|
|
|
|
2021-05-03 07:48:18 -07:00
|
|
|
class JaxStackTraceBeforeTransformation(Exception):
  """Marker exception carrying the traceback of the originating JAX operation.

  `user_context` attaches an instance as the `__cause__` of an exception
  raised inside it, and `has_user_context` looks for it in a cause chain.
  """
|
|
|
|
|
|
|
|
_message = (
|
|
|
|
'The preceding stack trace is the source of the JAX operation that, once '
|
|
|
|
'transformed by JAX, triggered the following exception.\n'
|
|
|
|
'\n--------------------')
|
|
|
|
|
|
|
|
def has_user_context(e):
  """Returns True if `e` or any exception in its `__cause__` chain is a
  `JaxStackTraceBeforeTransformation`.

  Only explicit causes (`__cause__`) are followed, not `__context__`.
  `None` yields False.
  """
  link = e
  while link is not None and not isinstance(
      link, JaxStackTraceBeforeTransformation):
    link = link.__cause__
  return link is not None
|
|
|
|
|
2020-11-04 11:54:01 -08:00
|
|
|
@contextlib.contextmanager
def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None):
  """Context manager that makes `c` / `name_stack` the active source info.

  While active, the thread-local context is the previous one with its
  traceback and/or name stack overridden (None leaves a field as it was). The
  previous context is always restored on exit.

  If an `Exception` escapes the body, `c` is not None, and the exception's
  cause chain does not already contain a `JaxStackTraceBeforeTransformation`,
  the filtered Python traceback of `c` is attached to the exception as a
  `JaxStackTraceBeforeTransformation` cause before it is re-raised.

  Args:
    c: traceback to install, or None to keep the current one.
    name_stack: name stack to install, or None to keep the current one.
  """
  saved = _source_info_context.context
  _source_info_context.context = saved.replace(
      traceback=c, name_stack=name_stack)
  filtered_tb = None
  try:
    yield
  except Exception as e:
    if c is None or has_user_context(e):
      raise
    filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
    if filtered_tb:
      summary = traceback_util.format_exception_only(e)
      msg = f'{summary}\n\n{_message}'
      exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
      # Splice `exp` in between `e` and whatever `e` was chained to: `exp`
      # inherits e's context/cause/suppress flag, then becomes e's cause.
      exp.__context__ = e.__context__
      exp.__cause__ = e.__cause__
      exp.__suppress_context__ = e.__suppress_context__
      e.__context__ = None
      e.__cause__ = exp
    raise
  finally:
    _source_info_context.context = saved
    # Drop the local reference to the traceback object before leaving.
    del filtered_tb
|
2021-10-28 11:06:58 -07:00
|
|
|
|
|
|
|
def current_name_stack() -> NameStack:
  """Returns the name stack of the calling thread's active source info."""
  active = _source_info_context.context
  return active.name_stack
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def extend_name_stack(name: str) -> Iterator[NameStack]:
  """Context manager that pushes `name` onto the active name stack.

  Yields the extended name stack; the previous context is restored on exit,
  including when the body raises.
  """
  saved = _source_info_context.context
  extended = saved.name_stack.extend(name)
  _source_info_context.context = saved.replace(name_stack=extended)
  try:
    yield extended
  finally:
    _source_info_context.context = saved
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def set_name_stack(name_stack: NameStack) -> Iterator[None]:
  """Context manager that makes `name_stack` the active name stack.

  The rest of the active source info is kept; the previous context is
  restored on exit, including when the body raises.
  """
  saved = _source_info_context.context
  _source_info_context.context = saved.replace(name_stack=name_stack)
  try:
    yield
  finally:
    _source_info_context.context = saved
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
  """Context manager that runs its body with an empty active name stack."""
  empty = NameStack()
  with set_name_stack(empty):
    yield
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def transform_name_stack(name: str) -> Iterator[NameStack]:
  """Context manager that appends a `Transform(name)` to the active name stack.

  Yields the transformed name stack; the previous context is restored on
  exit, including when the body raises.
  """
  saved = _source_info_context.context
  transformed = saved.name_stack.transform(name)
  _source_info_context.context = saved.replace(name_stack=transformed)
  try:
    yield transformed
  finally:
    _source_info_context.context = saved
|