mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Include column information in Python locations under Python 3.11.
https://peps.python.org/pep-0657/ means that we now have richer context information, which we can propagate where we use it, for example to the MHLO location in this example: ``` In [2]: jax.jit(lambda x: x + 2).lower(7).compiler_ir().operation.print(enable_debug_info=True) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) module @jit__lambda_ { func.func public @main(%arg0: tensor<i32> loc(unknown)) -> tensor<i32> { %0 = mhlo.constant dense<2> : tensor<i32> loc(#loc0) %1 = mhlo.add %arg0, %0 : tensor<i32> loc(#loc1) return %1 : tensor<i32> loc(#loc0) } loc(#loc0) } loc(#loc0) #loc1 = loc("jit(<lambda>)/jit(main)/add"("<ipython-input-2-525e569b8960>":1:18)) ```
This commit is contained in:
parent
38a7582923
commit
ec5bec6157
@ -33,7 +33,10 @@ Traceback = xla_client.Traceback
|
||||
class Frame(NamedTuple):
|
||||
file_name: str
|
||||
function_name: str
|
||||
line_num: int
|
||||
start_line: int
|
||||
start_column: int
|
||||
end_line: int
|
||||
end_column: int
|
||||
|
||||
|
||||
_exclude_paths = [os.path.dirname(jax.version.__file__)]
|
||||
@ -110,10 +113,21 @@ def is_user_filename(filename: str) -> bool:
|
||||
return (filename.endswith("_test.py") or
|
||||
not any(filename.startswith(p) for p in _exclude_paths))
|
||||
|
||||
def _raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
|
||||
return Frame(file_name=code.co_filename,
|
||||
function_name=code.co_name,
|
||||
line_num=xla_client.Traceback.code_addr2line(code, lasti))
|
||||
if hasattr(xla_client.Traceback, "code_addr2location"):
|
||||
# Python 3.11+
|
||||
def _raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
|
||||
loc = xla_client.Traceback.code_addr2location(code, lasti)
|
||||
start_line, start_column, end_line, end_column = loc
|
||||
return Frame(file_name=code.co_filename,
|
||||
function_name=code.co_name,
|
||||
start_line=start_line, start_column=start_column,
|
||||
end_line=end_line, end_column=end_column)
|
||||
else:
|
||||
def _raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
|
||||
return Frame(file_name=code.co_filename,
|
||||
function_name=code.co_name,
|
||||
start_line=xla_client.Traceback.code_addr2line(code, lasti),
|
||||
start_column=0, end_line=0, end_column=0)
|
||||
|
||||
def user_frames(source_info: SourceInfo) -> Iterator[Frame]:
|
||||
"""Iterator over the user's frames, filtering jax-internal frames."""
|
||||
@ -132,10 +146,17 @@ def user_frames(source_info: SourceInfo) -> Iterator[Frame]:
|
||||
def user_frame(source_info: SourceInfo) -> Optional[Frame]:
|
||||
return next(user_frames(source_info), None)
|
||||
|
||||
def _summarize_frame(frame: Frame) -> str:
|
||||
if frame.start_column != 0:
|
||||
return (f"{frame.file_name}:{frame.start_line}:{frame.start_column} "
|
||||
f"({frame.function_name})")
|
||||
else:
|
||||
return f"{frame.file_name}:{frame.start_line} ({frame.function_name})"
|
||||
|
||||
def summarize(source_info: SourceInfo, num_frames=1) -> str:
|
||||
frames = itertools.islice(user_frames(source_info), num_frames)
|
||||
frame_strs = [f"{frame.file_name}:{frame.line_num} ({frame.function_name})"
|
||||
if frame else "unknown" for frame in frames]
|
||||
frame_strs = [_summarize_frame(frame) if frame else "unknown"
|
||||
for frame in frames]
|
||||
return '\n'.join(reversed(frame_strs))
|
||||
|
||||
class _SourceInfoContext(threading.local):
|
||||
|
@ -1048,7 +1048,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_simple, x,
|
||||
[tf_test_util.OpMetadataGraph(tf_type="Sin",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 2,
|
||||
source_line=user_frame.start_line + 2,
|
||||
op_name="jax2tf(f_simple)/sin",
|
||||
op_type="sin")
|
||||
]
|
||||
@ -1072,17 +1072,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_caller, x,
|
||||
[tf_test_util.OpMetadataGraph(tf_type="Tanh",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 4,
|
||||
source_line=user_frame.start_line + 4,
|
||||
op_name="jax2tf(f_caller)/tanh",
|
||||
op_type="tanh"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="Cos",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 2,
|
||||
source_line=user_frame.start_line + 2,
|
||||
op_name="jax2tf(f_caller)/jit(f_callee)/cos",
|
||||
op_type="cos"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="Sin",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 6,
|
||||
source_line=user_frame.start_line + 6,
|
||||
op_name="jax2tf(f_caller)/sin",
|
||||
op_type="sin"),
|
||||
]
|
||||
@ -1106,17 +1106,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_caller, x,
|
||||
[tf_test_util.OpMetadataGraph(tf_type="Tanh",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 4,
|
||||
source_line=user_frame.start_line + 4,
|
||||
op_name="jax2tf(f_caller)/tanh",
|
||||
op_type="tanh"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="Cos",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 2,
|
||||
source_line=user_frame.start_line + 2,
|
||||
op_name="jax2tf(f_caller)/named(callee)/cos",
|
||||
op_type="cos"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="Sin",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 6,
|
||||
source_line=user_frame.start_line + 6,
|
||||
op_name="jax2tf(f_caller)/sin",
|
||||
op_type="sin"),
|
||||
]
|
||||
@ -1147,17 +1147,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_while_cond, x,
|
||||
[tf_test_util.OpMetadataGraph(tf_type="Cos",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 5,
|
||||
source_line=user_frame.start_line + 5,
|
||||
op_name="jax2tf(f_while_cond)/while/body/cos",
|
||||
op_type="cos"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="Sin",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 7,
|
||||
source_line=user_frame.start_line + 7,
|
||||
op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
|
||||
op_type="sin"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="FloorMod",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 6,
|
||||
source_line=user_frame.start_line + 6,
|
||||
op_name="jax2tf(f_while_cond)/while/body/rem",
|
||||
op_type="rem"),
|
||||
]
|
||||
@ -1192,12 +1192,12 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_while, x,
|
||||
[tf_test_util.OpMetadataGraph(tf_type="Sin",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 4,
|
||||
source_line=user_frame.start_line + 4,
|
||||
op_name="jax2tf(f_while)/while/body/sin",
|
||||
op_type="sin"),
|
||||
tf_test_util.OpMetadataGraph(tf_type="LessEqual",
|
||||
source_file=__file__,
|
||||
source_line=user_frame.line_num + 8,
|
||||
source_line=user_frame.start_line + 8,
|
||||
op_name="jax2tf(f_while)/while/body_pred/le",
|
||||
op_type="le"),
|
||||
]
|
||||
|
@ -323,7 +323,7 @@ def _source_info_to_location(
|
||||
loc = ir.Location.unknown()
|
||||
else:
|
||||
loc = ir.Location.file(xla._get_canonical_source_file(frame),
|
||||
frame.line_num, 1)
|
||||
frame.start_line, frame.start_column)
|
||||
loc = ir.Location.name(eqn_str, childLoc=loc)
|
||||
# TODO(phawkins): also include primitive.name as the operator type.
|
||||
return loc
|
||||
|
@ -115,7 +115,7 @@ def make_op_metadata(primitive: core.Primitive,
|
||||
op_type=primitive.name,
|
||||
op_name=eqn_str,
|
||||
source_file=_get_canonical_source_file(frame) if frame else None,
|
||||
source_line=frame.line_num if frame else None)
|
||||
source_line=frame.start_line if frame else None)
|
||||
|
||||
# Utilities
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user