[JAX] Move source_info_util into jax._src.

TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
This commit is contained in:
Peter Hawkins 2020-11-04 11:54:01 -08:00 committed by jax authors
parent 644de24977
commit 7efc1dbc94
10 changed files with 74 additions and 57 deletions

View File

@ -31,7 +31,7 @@ import jax
from jax import api
from jax import core
from jax import dtypes
from jax import source_info_util
from jax._src import source_info_util
from jax import util
from jax._src.lax import lax
from jax import linear_util as lu

View File

@ -23,7 +23,7 @@ import numpy as np
from jax import core
from jax import dtypes
from jax import tree_util
from jax import source_info_util
from jax._src import source_info_util
from . import lax
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.interpreters import ad

View File

@ -0,0 +1,63 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os.path
import threading
from typing import Any, Optional
from jax.lib import xla_client
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
Traceback = Any # xla_client.Traceback
Frame = Any # xla_client.Traceback::Frame
_jax_path = os.path.dirname(__file__)
def user_frame(source_info: Optional[Traceback]) -> Optional[Frame]:
"""Heuristic that guesses the identity of the user's code in a stack trace."""
# Guess the user's frame is the innermost frame not in the jax source tree
return next((x for x in (source_info.frames if source_info else [])
if not x.file_name.startswith(_jax_path)), None)
def summarize(source_info: Optional[Traceback]) -> str:
frame = user_frame(source_info)
return (f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
if frame else "unknown")
class _SourceInfoContext(threading.local):
context: Optional[Traceback]
def __init__(self):
self.context = None
_source_info_context = _SourceInfoContext()
def current() -> Optional[Traceback]:
return _source_info_context.context or xla_client.Traceback.get_traceback()
@contextlib.contextmanager
def user_context(c):
prev = _source_info_context.context
_source_info_context.context = c or _source_info_context.context
try:
yield
finally:
_source_info_context.context = prev

View File

@ -32,7 +32,7 @@ from . import dtypes
from .config import FLAGS, config
from . import linear_util as lu
from . import source_info_util
from jax._src import source_info_util
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
from .pprint_util import pp, vcat, PrettyPrint

View File

@ -180,7 +180,7 @@ from jax.lib import pytree
from jax.interpreters import ad, xla, batching, masking
from jax.interpreters import partial_eval as pe
from jax import pprint_util as ppu
from jax import source_info_util
from jax._src import source_info_util
from jax import util
from jaxlib import xla_client
from jaxlib import xla_extension

View File

@ -30,7 +30,7 @@ from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, Partial
from .. import source_info_util
from jax._src import source_info_util
zip = safe_zip
map = safe_map

View File

@ -32,7 +32,7 @@ from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
dropvar)
from .. import source_info_util
from jax._src import source_info_util
from ..config import config
map = safe_map

View File

@ -28,7 +28,7 @@ from .. import ad_util
from .. import dtypes
from .. import lazy
from .. import linear_util as lu
from .. import source_info_util
from jax._src import source_info_util
from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
make_shaped_array, array_types, raise_to_shaped,
abstract_token)

View File

@ -17,7 +17,8 @@
import collections
from typing import Any, Callable, Dict, List, Optional
from jax import core, source_info_util, util
from jax import core, util
from jax._src import source_info_util
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

View File

@ -12,52 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os.path
import threading
from typing import Any, Optional
from .lib import xla_client
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
Traceback = Any # xla_client.Traceback
Frame = Any # xla_client.Traceback::Frame
_jax_path = os.path.dirname(__file__)
def user_frame(source_info: Optional[Traceback]) -> Optional[Frame]:
"""Heuristic that guesses the identity of the user's code in a stack trace."""
# Guess the user's frame is the innermost frame not in the jax source tree
return next((x for x in (source_info.frames if source_info else [])
if not x.file_name.startswith(_jax_path)), None)
def summarize(source_info: Optional[Traceback]) -> str:
frame = user_frame(source_info)
return (f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
if frame else "unknown")
class _SourceInfoContext(threading.local):
context: Optional[Traceback]
def __init__(self):
self.context = None
_source_info_context = _SourceInfoContext()
def current() -> Optional[Traceback]:
return _source_info_context.context or xla_client.Traceback.get_traceback()
@contextlib.contextmanager
def user_context(c):
prev = _source_info_context.context
_source_info_context.context = c or _source_info_context.context
try:
yield
finally:
_source_info_context.context = prev
# flake8: noqa: F401
from jax._src.source_info_util import current