mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP. PiperOrigin-RevId: 340698612
This commit is contained in:
parent
644de24977
commit
7efc1dbc94
@ -31,7 +31,7 @@ import jax
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from jax import util
|
||||
from jax._src.lax import lax
|
||||
from jax import linear_util as lu
|
||||
|
@ -23,7 +23,7 @@ import numpy as np
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import tree_util
|
||||
from jax import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from . import lax
|
||||
from jax.abstract_arrays import ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
|
63
jax/_src/source_info_util.py
Normal file
63
jax/_src/source_info_util.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import os.path
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
from jax.lib import xla_client
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
Traceback = Any # xla_client.Traceback
|
||||
Frame = Any # xla_client.Traceback::Frame
|
||||
|
||||
_jax_path = os.path.dirname(__file__)
|
||||
|
||||
def user_frame(source_info: Optional[Traceback]) -> Optional[Frame]:
|
||||
"""Heuristic that guesses the identity of the user's code in a stack trace."""
|
||||
# Guess the user's frame is the innermost frame not in the jax source tree
|
||||
return next((x for x in (source_info.frames if source_info else [])
|
||||
if not x.file_name.startswith(_jax_path)), None)
|
||||
|
||||
|
||||
def summarize(source_info: Optional[Traceback]) -> str:
|
||||
frame = user_frame(source_info)
|
||||
return (f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
|
||||
if frame else "unknown")
|
||||
|
||||
|
||||
class _SourceInfoContext(threading.local):
|
||||
context: Optional[Traceback]
|
||||
|
||||
def __init__(self):
|
||||
self.context = None
|
||||
|
||||
_source_info_context = _SourceInfoContext()
|
||||
|
||||
|
||||
def current() -> Optional[Traceback]:
|
||||
return _source_info_context.context or xla_client.Traceback.get_traceback()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def user_context(c):
|
||||
prev = _source_info_context.context
|
||||
_source_info_context.context = c or _source_info_context.context
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_source_info_context.context = prev
|
@ -32,7 +32,7 @@ from . import dtypes
|
||||
from .config import FLAGS, config
|
||||
from . import linear_util as lu
|
||||
|
||||
from . import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
|
||||
from .pprint_util import pp, vcat, PrettyPrint
|
||||
|
||||
|
@ -180,7 +180,7 @@ from jax.lib import pytree
|
||||
from jax.interpreters import ad, xla, batching, masking
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax import pprint_util as ppu
|
||||
from jax import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from jax import util
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import xla_extension
|
||||
|
@ -30,7 +30,7 @@ from ..tree_util import register_pytree_node
|
||||
from .. import linear_util as lu
|
||||
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from .. import source_info_util
|
||||
from jax._src import source_info_util
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
|
@ -32,7 +32,7 @@ from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
|
||||
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
|
||||
dropvar)
|
||||
from .. import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from ..config import config
|
||||
|
||||
map = safe_map
|
||||
|
@ -28,7 +28,7 @@ from .. import ad_util
|
||||
from .. import dtypes
|
||||
from .. import lazy
|
||||
from .. import linear_util as lu
|
||||
from .. import source_info_util
|
||||
from jax._src import source_info_util
|
||||
from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
|
||||
make_shaped_array, array_types, raise_to_shaped,
|
||||
abstract_token)
|
||||
|
@ -17,7 +17,8 @@
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from jax import core, source_info_util, util
|
||||
from jax import core, util
|
||||
from jax._src import source_info_util
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
@ -12,52 +12,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import os.path
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
from .lib import xla_client
|
||||
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
Traceback = Any # xla_client.Traceback
|
||||
Frame = Any # xla_client.Traceback::Frame
|
||||
|
||||
_jax_path = os.path.dirname(__file__)
|
||||
|
||||
def user_frame(source_info: Optional[Traceback]) -> Optional[Frame]:
|
||||
"""Heuristic that guesses the identity of the user's code in a stack trace."""
|
||||
# Guess the user's frame is the innermost frame not in the jax source tree
|
||||
return next((x for x in (source_info.frames if source_info else [])
|
||||
if not x.file_name.startswith(_jax_path)), None)
|
||||
|
||||
|
||||
def summarize(source_info: Optional[Traceback]) -> str:
|
||||
frame = user_frame(source_info)
|
||||
return (f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
|
||||
if frame else "unknown")
|
||||
|
||||
|
||||
class _SourceInfoContext(threading.local):
|
||||
context: Optional[Traceback]
|
||||
|
||||
def __init__(self):
|
||||
self.context = None
|
||||
|
||||
_source_info_context = _SourceInfoContext()
|
||||
|
||||
|
||||
def current() -> Optional[Traceback]:
|
||||
return _source_info_context.context or xla_client.Traceback.get_traceback()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def user_context(c):
|
||||
prev = _source_info_context.context
|
||||
_source_info_context.context = c or _source_info_context.context
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_source_info_context.context = prev
|
||||
# flake8: noqa: F401
|
||||
from jax._src.source_info_util import current
|
||||
|
Loading…
x
Reference in New Issue
Block a user