mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a cache mapping complete traceback objects to MLIR IR locations.
This turns out to be profitable in at least one benchmark. PiperOrigin-RevId: 609726977
This commit is contained in:
parent
9088e5c6f6
commit
0080b36a3d
@ -360,6 +360,10 @@ def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
|
||||
|
||||
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
||||
"""Converts a full traceback to a callsite() MLIR location."""
|
||||
loc = ctx.traceback_caches.traceback_cache.get(tb, None)
|
||||
if loc is not None:
|
||||
return loc
|
||||
|
||||
frame_locs = []
|
||||
frames_limit = config.traceback_in_locations_limit.value
|
||||
frames_limit = frames_limit if frames_limit >= 0 else 1000
|
||||
@ -387,11 +391,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
||||
|
||||
n = len(frame_locs)
|
||||
if n == 0:
|
||||
return ir.Location.unknown()
|
||||
loc = ir.Location.unknown()
|
||||
elif n == 1:
|
||||
return frame_locs[0]
|
||||
loc = frame_locs[0]
|
||||
else:
|
||||
return ir.Location.callsite(frame_locs[0], frame_locs[1:])
|
||||
loc = ir.Location.callsite(frame_locs[0], frame_locs[1:])
|
||||
ctx.traceback_caches.traceback_cache[tb] = loc
|
||||
return loc
|
||||
|
||||
def _source_info_to_location(
|
||||
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
|
||||
@ -569,11 +575,13 @@ class LoweringParameters:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TracebackCaches:
|
||||
traceback_cache: dict[xc.Traceback, ir.Location]
|
||||
location_cache: dict[tuple[types.CodeType, int], ir.Location]
|
||||
canonical_name_cache: dict[str, str]
|
||||
is_user_file_cache: dict[str, bool]
|
||||
|
||||
def __init__(self):
|
||||
self.traceback_cache = {}
|
||||
self.location_cache = {}
|
||||
self.canonical_name_cache = {}
|
||||
self.is_user_file_cache = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user