Add a cache mapping complete traceback objects to MLIR IR locations.

This turns out to be profitable in at least one benchmark.

PiperOrigin-RevId: 609726977
This commit is contained in:
Peter Hawkins 2024-02-23 07:51:09 -08:00 committed by jax authors
parent 9088e5c6f6
commit 0080b36a3d

View File

@ -360,6 +360,10 @@ def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
loc = ctx.traceback_caches.traceback_cache.get(tb, None)
if loc is not None:
return loc
frame_locs = []
frames_limit = config.traceback_in_locations_limit.value
frames_limit = frames_limit if frames_limit >= 0 else 1000
@ -387,11 +391,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
n = len(frame_locs)
if n == 0:
return ir.Location.unknown()
loc = ir.Location.unknown()
elif n == 1:
return frame_locs[0]
loc = frame_locs[0]
else:
return ir.Location.callsite(frame_locs[0], frame_locs[1:])
loc = ir.Location.callsite(frame_locs[0], frame_locs[1:])
ctx.traceback_caches.traceback_cache[tb] = loc
return loc
def _source_info_to_location(
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
@ -569,11 +575,13 @@ class LoweringParameters:
@dataclasses.dataclass
class TracebackCaches:
traceback_cache: dict[xc.Traceback, ir.Location]
location_cache: dict[tuple[types.CodeType, int], ir.Location]
canonical_name_cache: dict[str, str]
is_user_file_cache: dict[str, bool]
def __init__(self):
self.traceback_cache = {}
self.location_cache = {}
self.canonical_name_cache = {}
self.is_user_file_cache = {}