mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Small optimizations to mlir._traceback_to_location.
Cache construction of complete MLIR locations for each frame. PiperOrigin-RevId: 606587073
This commit is contained in:
parent
0ddf37145e
commit
2a8634b602
@ -25,6 +25,7 @@ import itertools
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
import typing
|
||||
from typing import Any, Callable, NamedTuple, Protocol, Union
|
||||
import warnings
|
||||
@ -363,35 +364,31 @@ def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
|
||||
ctx.traceback_caches.is_user_file_cache[file_name] = result
|
||||
return result
|
||||
|
||||
def _raw_frame_to_frame(ctx: ModuleContext,
|
||||
code: source_info_util.types.CodeType, lasti: int):
|
||||
key = (code.co_filename, lasti)
|
||||
if key in ctx.traceback_caches.raw_frame_to_frame_cache:
|
||||
return ctx.traceback_caches.raw_frame_to_frame_cache[key]
|
||||
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
||||
ctx.traceback_caches.raw_frame_to_frame_cache[key] = frame
|
||||
return frame
|
||||
|
||||
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
||||
"""Converts a full traceback to a callsite() MLIR location."""
|
||||
frame_locs = []
|
||||
frames_limit = config.traceback_in_locations_limit.value
|
||||
if frames_limit == 0:
|
||||
return ir.Location.unknown()
|
||||
frames_limit = frames_limit if frames_limit >= 0 else 1000
|
||||
|
||||
for code, lasti in zip(*tb.raw_frames()):
|
||||
codes, lastis = tb.raw_frames()
|
||||
for i, code in enumerate(codes):
|
||||
if not _is_user_file(ctx, code.co_filename):
|
||||
continue
|
||||
|
||||
frame = _raw_frame_to_frame(ctx, code, lasti)
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
||||
frame_locs.append(name_loc)
|
||||
if frames_limit > 0 and len(frame_locs) >= frames_limit:
|
||||
lasti = lastis[i]
|
||||
code_lasti = code, lasti
|
||||
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
|
||||
if loc is None:
|
||||
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
||||
ctx.traceback_caches.location_cache[code_lasti] = loc
|
||||
frame_locs.append(loc)
|
||||
if len(frame_locs) >= frames_limit:
|
||||
break
|
||||
|
||||
if len(frame_locs) == 0:
|
||||
@ -578,14 +575,14 @@ class LoweringParameters:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TracebackCaches:
|
||||
location_cache: dict[tuple[types.CodeType, int], ir.Location]
|
||||
canonical_name_cache: dict[str, str]
|
||||
is_user_file_cache: dict[str, bool]
|
||||
raw_frame_to_frame_cache: dict[tuple[str, int], source_info_util.Frame]
|
||||
|
||||
def __init__(self):
|
||||
self.location_cache = {}
|
||||
self.canonical_name_cache = {}
|
||||
self.is_user_file_cache = {}
|
||||
self.raw_frame_to_frame_cache = {}
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
@ -1315,9 +1312,9 @@ def lower_jaxpr_to_fun(
|
||||
output_ids = util.unflatten(list(range(len(flat_output_types))),
|
||||
map(len, output_types))
|
||||
aliases: list[int | None] = []
|
||||
for types, alias in zip(input_types, input_output_aliases):
|
||||
for itypes, alias in zip(input_types, input_output_aliases):
|
||||
if alias is None:
|
||||
aliases.extend([None] * len(types))
|
||||
aliases.extend([None] * len(itypes))
|
||||
else:
|
||||
aliases.extend(output_ids[alias])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user