diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1a4587300..e39bf36c8 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -330,25 +330,50 @@ register_constant_handler(core.Token, _token_constant_handler) # Source locations -def get_canonical_source_file(frame: source_info_util.Frame) -> str: - source_file = frame.file_name - if pattern := config.hlo_source_file_canonicalization_regex.value: +def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: + if file_name in caches.canonical_name_cache: + return caches.canonical_name_cache[file_name] + + source_file = file_name + pattern = config.hlo_source_file_canonicalization_regex.value + if pattern: source_file = re.sub(pattern, '', source_file) + + caches.canonical_name_cache[file_name] = source_file return source_file -def _traceback_to_location(tb: xc.Traceback) -> ir.Location: +def _is_user_file(ctx: ModuleContext, file_name: str) -> bool: + if file_name in ctx.traceback_caches.is_user_file_cache: + return ctx.traceback_caches.is_user_file_cache[file_name] + + result = source_info_util.is_user_filename(file_name) + ctx.traceback_caches.is_user_file_cache[file_name] = result + return result + +def _raw_frame_to_frame(ctx: ModuleContext, + code: source_info_util.types.CodeType, lasti: int): + key = (code, lasti)  # key on the code object: (co_filename, lasti) collides across functions in one file + if key in ctx.traceback_caches.raw_frame_to_frame_cache: + return ctx.traceback_caches.raw_frame_to_frame_cache[key] + frame = source_info_util.raw_frame_to_frame(code, lasti) + ctx.traceback_caches.raw_frame_to_frame_cache[key] = frame + return frame + +def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: """Converts a full traceback to a callsite() MLIR location.""" frame_locs = [] for code, lasti in zip(*tb.raw_frames()): - frame = source_info_util.raw_frame_to_frame(code, lasti) - if source_info_util.is_user_filename(frame.file_name): - file_loc = ir.Location.file( - 
get_canonical_source_file(frame), - frame.start_line, - frame.start_column, - ) - name_loc = ir.Location.name(frame.function_name, childLoc=file_loc) - frame_locs.append(name_loc) + if not _is_user_file(ctx, code.co_filename): + continue + + frame = _raw_frame_to_frame(ctx, code, lasti) + file_loc = ir.Location.file( + get_canonical_source_file(frame.file_name, ctx.traceback_caches), + frame.start_line, + frame.start_column, + ) + name_loc = ir.Location.name(frame.function_name, childLoc=file_loc) + frame_locs.append(name_loc) if len(frame_locs) == 0: return ir.Location.unknown() @@ -359,22 +384,22 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location: return ir.Location.callsite(frame_locs[0], frame_locs[1:]) def _source_info_to_location( - primitive: core.Primitive, params: dict, - source_info: source_info_util.SourceInfo, - name_stack: source_info_util.NameStack) -> ir.Location: + ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], + source_info: source_info_util.SourceInfo) -> ir.Location: eqn_str = (f'{source_info.name_stack}/' f'{core.str_eqn_compact(primitive.name, params)}') if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() else: - loc = _traceback_to_location(source_info.traceback) + loc = _traceback_to_location(ctx, source_info.traceback) else: frame = source_info_util.user_frame(source_info) if frame is None: loc = ir.Location.unknown() else: - loc = ir.Location.file(get_canonical_source_file(frame), + loc = ir.Location.file(get_canonical_source_file(frame.file_name, + ctx.traceback_caches), frame.start_line, frame.start_column) loc = ir.Location.name(eqn_str, childLoc=loc) # TODO(phawkins): also include primitive.name as the operator type. @@ -532,6 +557,16 @@ class LoweringParameters: # native execution (and we can remove this parameter). 
replace_tokens_with_dummy: bool = True +@dataclasses.dataclass +class TracebackCaches: + canonical_name_cache: dict[str, str] + is_user_file_cache: dict[str, bool] + raw_frame_to_frame_cache: dict[tuple[source_info_util.types.CodeType, int], source_info_util.Frame] + + def __init__(self): + self.canonical_name_cache = {} + self.is_user_file_cache = {} + self.raw_frame_to_frame_cache = {} @dataclasses.dataclass class ModuleContext: @@ -553,6 +588,9 @@ class ModuleContext: # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] + # Cached traceback information. + traceback_caches: TracebackCaches + lowering_parameters: LoweringParameters @property @@ -576,6 +614,7 @@ class ModuleContext: symbol_table: ir.SymbolTable | None = None, cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, + traceback_caches: None | TracebackCaches = None, shape_poly_state = None): self.context = context or make_ir_context() @@ -588,6 +627,8 @@ class ModuleContext: self.name_stack = name_stack self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None else cached_primitive_lowerings) + self.traceback_caches = (TracebackCaches() if traceback_caches is None + else traceback_caches) self.channel_iterator = channel_iterator self.keepalives = keepalives self.host_callbacks = host_callbacks @@ -1508,8 +1549,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) source_info = eqn.source_info.replace( name_stack=ctx.name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(eqn.primitive, eqn.params, source_info, - ctx.name_stack) + loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} diff --git a/jax/_src/pallas/mosaic/lowering.py 
b/jax/_src/pallas/mosaic/lowering.py index 90e91e77c..a734effd3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -86,6 +86,8 @@ class LoweringContext: name_stack: source_info_util.NameStack mesh_context: MeshContext | None replace = dataclasses.replace + traceback_caches: mlir.TracebackCaches + @dataclasses.dataclass @@ -423,6 +425,7 @@ def lower_jaxpr_to_transform_func( arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + traceback_caches=mlir.TracebackCaches(), ) return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices, *scalar_prefetch) @@ -478,6 +481,7 @@ def lower_jaxpr_to_func( arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + traceback_caches=mlir.TracebackCaches(), ) return jaxpr_subcomp( lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch @@ -546,7 +550,7 @@ def jaxpr_subcomp( name_stack=ctx.name_stack + eqn.source_info.name_stack ) loc = mlir._source_info_to_location( - eqn.primitive, eqn.params, source_info, ctx.name_stack + ctx, eqn.primitive, eqn.params, source_info ) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in lowering_rules: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 11f2d7ba3..4b477586a 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1244,7 +1244,8 @@ def _make_op_metadata(primitive: core.Primitive, return xla_client.OpMetadata( op_type=primitive.name, op_name=eqn_str, - source_file=mlir.get_canonical_source_file(frame) if frame else None, + source_file=(mlir.get_canonical_source_file( + frame.file_name, mlir.TracebackCaches()) if frame else None), source_line=frame.start_line if frame else None)