Small optimizations to mlir._traceback_to_location.

Cache construction of complete MLIR locations for each frame.

PiperOrigin-RevId: 606587073
This commit is contained in:
Peter Hawkins 2024-02-13 05:28:27 -08:00 committed by jax authors
parent 0ddf37145e
commit 2a8634b602

View File

@ -25,6 +25,7 @@ import itertools
import operator
import os
import re
import types
import typing
from typing import Any, Callable, NamedTuple, Protocol, Union
import warnings
@ -363,35 +364,31 @@ def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
ctx.traceback_caches.is_user_file_cache[file_name] = result
return result
def _raw_frame_to_frame(ctx: ModuleContext,
                        code: source_info_util.types.CodeType, lasti: int):
  """Returns the frame for a raw `(code, lasti)` pair, memoized on `ctx`.

  Args:
    ctx: module context whose `traceback_caches.raw_frame_to_frame_cache`
      dict stores previously converted frames.
    code: code object of the raw traceback frame.
    lasti: instruction offset inside `code` (presumably the offset of the last
      executed instruction, as in `frame.f_lasti` -- confirm against
      `source_info_util.raw_frame_to_frame`).

  Returns:
    Whatever `source_info_util.raw_frame_to_frame(code, lasti)` returns for
    this pair; the conversion runs at most once per distinct cache key.
  """
  cache = ctx.traceback_caches.raw_frame_to_frame_cache
  # `lasti` is only meaningful relative to one code object, so the key has to
  # identify the code object and not just its file: keying on
  # (co_filename, lasti) alone lets two different functions of the same file
  # that happen to share an offset reuse each other's frame. The file name is
  # kept in the key so that it is a strict refinement of the previous one.
  # Typed as `Any` because the cache is still declared with `tuple[str, int]`
  # keys on TracebackCaches.
  key: Any = (code.co_filename, code, lasti)
  try:
    return cache[key]
  except KeyError:
    pass
  frame = source_info_util.raw_frame_to_frame(code, lasti)
  cache[key] = frame
  return frame
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
frame_locs = []
frames_limit = config.traceback_in_locations_limit.value
if frames_limit == 0:
return ir.Location.unknown()
frames_limit = frames_limit if frames_limit >= 0 else 1000
for code, lasti in zip(*tb.raw_frames()):
codes, lastis = tb.raw_frames()
for i, code in enumerate(codes):
if not _is_user_file(ctx, code.co_filename):
continue
frame = _raw_frame_to_frame(ctx, code, lasti)
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
frame_locs.append(name_loc)
if frames_limit > 0 and len(frame_locs) >= frames_limit:
lasti = lastis[i]
code_lasti = code, lasti
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
if loc is None:
frame = source_info_util.raw_frame_to_frame(code, lasti)
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
ctx.traceback_caches.location_cache[code_lasti] = loc
frame_locs.append(loc)
if len(frame_locs) >= frames_limit:
break
if len(frame_locs) == 0:
@ -578,14 +575,14 @@ class LoweringParameters:
@dataclasses.dataclass
class TracebackCaches:
  """Bundle of memoization dicts used when converting tracebacks to locations.

  Every cache starts out empty; the explicit `__init__` below takes no
  arguments, so instances are always created as `TracebackCaches()`.
  """
  # Finished MLIR locations, keyed by (code object, lasti).
  location_cache: dict[tuple[types.CodeType, int], ir.Location]
  # String-to-string cache; presumably raw file name -> canonical file name,
  # as consumed by get_canonical_source_file -- confirm against that helper.
  canonical_name_cache: dict[str, str]
  # File name -> result of the "is this a user file" check.
  is_user_file_cache: dict[str, bool]
  # Frames memoized by _raw_frame_to_frame.
  raw_frame_to_frame_cache: dict[tuple[str, int], source_info_util.Frame]

  def __init__(self):
    # Each attribute gets its own fresh dict so instances never share state.
    self.location_cache = dict()
    self.canonical_name_cache = dict()
    self.is_user_file_cache = dict()
    self.raw_frame_to_frame_cache = dict()
@dataclasses.dataclass
class ModuleContext:
@ -1315,9 +1312,9 @@ def lower_jaxpr_to_fun(
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: list[int | None] = []
for types, alias in zip(input_types, input_output_aliases):
for itypes, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
aliases.extend([None] * len(itypes))
else:
aliases.extend(output_ids[alias])