[Jax] Speedup stacks traceback

PiperOrigin-RevId: 595416135
This commit is contained in:
jax authors 2024-01-03 09:00:29 -08:00
parent afa2f1e420
commit b6136795dd
3 changed files with 67 additions and 22 deletions

View File

@ -330,25 +330,50 @@ register_constant_handler(core.Token, _token_constant_handler)
# Source locations
def get_canonical_source_file(frame: source_info_util.Frame) -> str:
source_file = frame.file_name
if pattern := config.hlo_source_file_canonicalization_regex.value:
def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
  """Returns the canonicalized form of `file_name`, memoized in `caches`.

  Canonicalization removes every match of the regex configured in
  `config.hlo_source_file_canonicalization_regex`; when that option is unset
  (falsy) the name is passed through unchanged.

  Args:
    file_name: the source file name to canonicalize.
    caches: holder of `canonical_name_cache`, which maps already-seen file
      names to their canonical form and is updated on a miss.

  Returns:
    The canonical source file name.
  """
  memo = caches.canonical_name_cache
  try:
    return memo[file_name]
  except KeyError:
    pass
  regex = config.hlo_source_file_canonicalization_regex.value
  canonical = re.sub(regex, '', file_name) if regex else file_name
  # Remember the answer (even when unchanged) so later lookups skip the regex.
  memo[file_name] = canonical
  return canonical
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
  """Memoized wrapper around `source_info_util.is_user_filename`.

  Args:
    ctx: module context whose `traceback_caches.is_user_file_cache` holds the
      answers computed so far.
    file_name: the file name to classify.

  Returns:
    The (possibly cached) result of `source_info_util.is_user_filename`.
  """
  memo = ctx.traceback_caches.is_user_file_cache
  try:
    return memo[file_name]
  except KeyError:
    pass
  verdict = source_info_util.is_user_filename(file_name)
  memo[file_name] = verdict
  return verdict
def _raw_frame_to_frame(ctx: ModuleContext,
                        code: source_info_util.types.CodeType, lasti: int):
  """Decodes a raw `(code, lasti)` traceback entry into a frame, with caching.

  Args:
    ctx: module context whose `traceback_caches.raw_frame_to_frame_cache`
      stores previously decoded frames.
    code: the code object of the raw frame.
    lasti: the integer paired with `code` in the raw traceback; presumably the
      instruction offset within `code` (cf. `frame.f_lasti`) -- TODO confirm.

  Returns:
    The result of `source_info_util.raw_frame_to_frame(code, lasti)`, reused
    from the cache when the same `(code, lasti)` pair was decoded before.
  """
  # Key on the code object itself, not just on `code.co_filename`: the frame is
  # derived from `code`, so two different code objects that live in the same
  # file and are paired with an equal `lasti` must not share a cache entry.
  key = (code, lasti)
  cache = ctx.traceback_caches.raw_frame_to_frame_cache
  if key in cache:
    return cache[key]
  frame = source_info_util.raw_frame_to_frame(code, lasti)
  cache[key] = frame
  return frame
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
frame_locs = []
for code, lasti in zip(*tb.raw_frames()):
frame = source_info_util.raw_frame_to_frame(code, lasti)
if source_info_util.is_user_filename(frame.file_name):
file_loc = ir.Location.file(
get_canonical_source_file(frame),
frame.start_line,
frame.start_column,
)
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
frame_locs.append(name_loc)
if not _is_user_file(ctx, code.co_filename):
continue
frame = _raw_frame_to_frame(ctx, code, lasti)
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
frame_locs.append(name_loc)
if len(frame_locs) == 0:
return ir.Location.unknown()
@ -359,22 +384,22 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
return ir.Location.callsite(frame_locs[0], frame_locs[1:])
def _source_info_to_location(
primitive: core.Primitive, params: dict,
source_info: source_info_util.SourceInfo,
name_stack: source_info_util.NameStack) -> ir.Location:
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
source_info: source_info_util.SourceInfo) -> ir.Location:
eqn_str = (f'{source_info.name_stack}/'
f'{core.str_eqn_compact(primitive.name, params)}')
if config.include_full_tracebacks_in_locations.value:
if source_info.traceback is None:
loc = ir.Location.unknown()
else:
loc = _traceback_to_location(source_info.traceback)
loc = _traceback_to_location(ctx, source_info.traceback)
else:
frame = source_info_util.user_frame(source_info)
if frame is None:
loc = ir.Location.unknown()
else:
loc = ir.Location.file(get_canonical_source_file(frame),
loc = ir.Location.file(get_canonical_source_file(frame.file_name,
ctx.traceback_caches),
frame.start_line, frame.start_column)
loc = ir.Location.name(eqn_str, childLoc=loc)
# TODO(phawkins): also include primitive.name as the operator type.
@ -532,6 +557,16 @@ class LoweringParameters:
# native execution (and we can remove this parameter).
replace_tokens_with_dummy: bool = True
@dataclasses.dataclass
class TracebackCaches:
  """Memoization tables used while turning tracebacks into source locations.

  Every field defaults to a fresh, empty dict, so `TracebackCaches()` yields an
  independent set of caches; pre-populated dicts may also be passed by keyword.
  """
  # File name -> canonicalized file name.
  canonical_name_cache: dict[str, str] = dataclasses.field(
      default_factory=dict)
  # File name -> whether the file was classified as a user file.
  is_user_file_cache: dict[str, bool] = dataclasses.field(
      default_factory=dict)
  # (code identifier, lasti) -> decoded frame. The first key element is
  # whatever the lookup helper uses to identify the code being executed.
  raw_frame_to_frame_cache: dict[tuple[Any, int], source_info_util.Frame] = (
      dataclasses.field(default_factory=dict))
@dataclasses.dataclass
class ModuleContext:
@ -553,6 +588,9 @@ class ModuleContext:
# Cached primitive lowerings.
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
# Cached traceback information.
traceback_caches: TracebackCaches
lowering_parameters: LoweringParameters
@property
@ -576,6 +614,7 @@ class ModuleContext:
symbol_table: ir.SymbolTable | None = None,
cached_primitive_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
traceback_caches: None | TracebackCaches = None,
shape_poly_state = None):
self.context = context or make_ir_context()
@ -588,6 +627,8 @@ class ModuleContext:
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
else cached_primitive_lowerings)
self.traceback_caches = (TracebackCaches() if traceback_caches is None
else traceback_caches)
self.channel_iterator = channel_iterator
self.keepalives = keepalives
self.host_callbacks = host_callbacks
@ -1508,8 +1549,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
ctx.name_stack)
loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
platform_rules: dict[str, LoweringRule] = {}

View File

@ -86,6 +86,8 @@ class LoweringContext:
name_stack: source_info_util.NameStack
mesh_context: MeshContext | None
replace = dataclasses.replace
traceback_caches: mlir.TracebackCaches
@dataclasses.dataclass
@ -423,6 +425,7 @@ def lower_jaxpr_to_transform_func(
arg_block_shapes,
source_info_util.NameStack(),
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
)
return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
*scalar_prefetch)
@ -478,6 +481,7 @@ def lower_jaxpr_to_func(
arg_block_shapes,
source_info_util.NameStack(),
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
)
return jaxpr_subcomp(
lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
@ -546,7 +550,7 @@ def jaxpr_subcomp(
name_stack=ctx.name_stack + eqn.source_info.name_stack
)
loc = mlir._source_info_to_location(
eqn.primitive, eqn.params, source_info, ctx.name_stack
ctx, eqn.primitive, eqn.params, source_info
)
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive in lowering_rules:

View File

@ -1244,7 +1244,8 @@ def _make_op_metadata(primitive: core.Primitive,
return xla_client.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=mlir.get_canonical_source_file(frame) if frame else None,
source_file=mlir.get_canonical_source_file(
frame.file_name if frame else "", mlir.TracebackCaches()),
source_line=frame.start_line if frame else None)