mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Jax] Speedup stacks traceback
PiperOrigin-RevId: 595416135
This commit is contained in:
parent
afa2f1e420
commit
b6136795dd
@ -330,25 +330,50 @@ register_constant_handler(core.Token, _token_constant_handler)
|
||||
|
||||
# Source locations
|
||||
|
||||
def get_canonical_source_file(frame: source_info_util.Frame) -> str:
|
||||
source_file = frame.file_name
|
||||
if pattern := config.hlo_source_file_canonicalization_regex.value:
|
||||
def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
|
||||
if file_name in caches.canonical_name_cache:
|
||||
return caches.canonical_name_cache[file_name]
|
||||
|
||||
source_file = file_name
|
||||
pattern = config.hlo_source_file_canonicalization_regex.value
|
||||
if pattern:
|
||||
source_file = re.sub(pattern, '', source_file)
|
||||
|
||||
caches.canonical_name_cache[file_name] = source_file
|
||||
return source_file
|
||||
|
||||
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
|
||||
def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
|
||||
if file_name in ctx.traceback_caches.is_user_file_cache:
|
||||
return ctx.traceback_caches.is_user_file_cache[file_name]
|
||||
|
||||
result = source_info_util.is_user_filename(file_name)
|
||||
ctx.traceback_caches.is_user_file_cache[file_name] = result
|
||||
return result
|
||||
|
||||
def _raw_frame_to_frame(ctx: ModuleContext,
|
||||
code: source_info_util.types.CodeType, lasti: int):
|
||||
key = (code.co_filename, lasti)
|
||||
if key in ctx.traceback_caches.raw_frame_to_frame_cache:
|
||||
return ctx.traceback_caches.raw_frame_to_frame_cache[key]
|
||||
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
||||
ctx.traceback_caches.raw_frame_to_frame_cache[key] = frame
|
||||
return frame
|
||||
|
||||
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
||||
"""Converts a full traceback to a callsite() MLIR location."""
|
||||
frame_locs = []
|
||||
for code, lasti in zip(*tb.raw_frames()):
|
||||
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
||||
if source_info_util.is_user_filename(frame.file_name):
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
||||
frame_locs.append(name_loc)
|
||||
if not _is_user_file(ctx, code.co_filename):
|
||||
continue
|
||||
|
||||
frame = _raw_frame_to_frame(ctx, code, lasti)
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
name_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
||||
frame_locs.append(name_loc)
|
||||
|
||||
if len(frame_locs) == 0:
|
||||
return ir.Location.unknown()
|
||||
@ -359,22 +384,22 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
|
||||
return ir.Location.callsite(frame_locs[0], frame_locs[1:])
|
||||
|
||||
def _source_info_to_location(
|
||||
primitive: core.Primitive, params: dict,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: source_info_util.NameStack) -> ir.Location:
|
||||
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
|
||||
source_info: source_info_util.SourceInfo) -> ir.Location:
|
||||
eqn_str = (f'{source_info.name_stack}/'
|
||||
f'{core.str_eqn_compact(primitive.name, params)}')
|
||||
if config.include_full_tracebacks_in_locations.value:
|
||||
if source_info.traceback is None:
|
||||
loc = ir.Location.unknown()
|
||||
else:
|
||||
loc = _traceback_to_location(source_info.traceback)
|
||||
loc = _traceback_to_location(ctx, source_info.traceback)
|
||||
else:
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
if frame is None:
|
||||
loc = ir.Location.unknown()
|
||||
else:
|
||||
loc = ir.Location.file(get_canonical_source_file(frame),
|
||||
loc = ir.Location.file(get_canonical_source_file(frame.file_name,
|
||||
ctx.traceback_caches),
|
||||
frame.start_line, frame.start_column)
|
||||
loc = ir.Location.name(eqn_str, childLoc=loc)
|
||||
# TODO(phawkins): also include primitive.name as the operator type.
|
||||
@ -532,6 +557,16 @@ class LoweringParameters:
|
||||
# native execution (and we can remove this parameter).
|
||||
replace_tokens_with_dummy: bool = True
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TracebackCaches:
|
||||
canonical_name_cache: dict[str, str]
|
||||
is_user_file_cache: dict[str, bool]
|
||||
raw_frame_to_frame_cache: dict[tuple[str, int], source_info_util.Frame]
|
||||
|
||||
def __init__(self):
|
||||
self.canonical_name_cache = {}
|
||||
self.is_user_file_cache = {}
|
||||
self.raw_frame_to_frame_cache = {}
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
@ -553,6 +588,9 @@ class ModuleContext:
|
||||
# Cached primitive lowerings.
|
||||
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
|
||||
|
||||
# Cached traceback infromation.
|
||||
traceback_caches: TracebackCaches
|
||||
|
||||
lowering_parameters: LoweringParameters
|
||||
|
||||
@property
|
||||
@ -576,6 +614,7 @@ class ModuleContext:
|
||||
symbol_table: ir.SymbolTable | None = None,
|
||||
cached_primitive_lowerings: None | (dict[Any,
|
||||
func_dialect.FuncOp]) = None,
|
||||
traceback_caches: None | TracebackCaches = None,
|
||||
shape_poly_state = None):
|
||||
|
||||
self.context = context or make_ir_context()
|
||||
@ -588,6 +627,8 @@ class ModuleContext:
|
||||
self.name_stack = name_stack
|
||||
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
|
||||
else cached_primitive_lowerings)
|
||||
self.traceback_caches = (TracebackCaches() if traceback_caches is None
|
||||
else traceback_caches)
|
||||
self.channel_iterator = channel_iterator
|
||||
self.keepalives = keepalives
|
||||
self.host_callbacks = host_callbacks
|
||||
@ -1508,8 +1549,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
|
||||
source_info = eqn.source_info.replace(
|
||||
name_stack=ctx.name_stack + eqn.source_info.name_stack)
|
||||
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
|
||||
ctx.name_stack)
|
||||
loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
override_rule = get_override_lowering_rule(eqn.primitive)
|
||||
platform_rules: dict[str, LoweringRule] = {}
|
||||
|
@ -86,6 +86,8 @@ class LoweringContext:
|
||||
name_stack: source_info_util.NameStack
|
||||
mesh_context: MeshContext | None
|
||||
replace = dataclasses.replace
|
||||
traceback_caches: mlir.TracebackCaches
|
||||
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -423,6 +425,7 @@ def lower_jaxpr_to_transform_func(
|
||||
arg_block_shapes,
|
||||
source_info_util.NameStack(),
|
||||
mesh_context=mesh_context,
|
||||
traceback_caches=mlir.TracebackCaches(),
|
||||
)
|
||||
return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
|
||||
*scalar_prefetch)
|
||||
@ -478,6 +481,7 @@ def lower_jaxpr_to_func(
|
||||
arg_block_shapes,
|
||||
source_info_util.NameStack(),
|
||||
mesh_context=mesh_context,
|
||||
traceback_caches=mlir.TracebackCaches(),
|
||||
)
|
||||
return jaxpr_subcomp(
|
||||
lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
|
||||
@ -546,7 +550,7 @@ def jaxpr_subcomp(
|
||||
name_stack=ctx.name_stack + eqn.source_info.name_stack
|
||||
)
|
||||
loc = mlir._source_info_to_location(
|
||||
eqn.primitive, eqn.params, source_info, ctx.name_stack
|
||||
ctx, eqn.primitive, eqn.params, source_info
|
||||
)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
if eqn.primitive in lowering_rules:
|
||||
|
@ -1244,7 +1244,8 @@ def _make_op_metadata(primitive: core.Primitive,
|
||||
return xla_client.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
op_name=eqn_str,
|
||||
source_file=mlir.get_canonical_source_file(frame) if frame else None,
|
||||
source_file=mlir.get_canonical_source_file(
|
||||
frame.file_name if frame else "", mlir.TracebackCaches()),
|
||||
source_line=frame.start_line if frame else None)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user