[Mosaic TPU] Append dump id to timestamp to make dump list ordered

PiperOrigin-RevId: 715488504
This commit is contained in:
Jevin Jiang 2025-01-14 12:43:32 -08:00 committed by jax authors
parent f122f17b27
commit 6851700ed4

View File

@ -306,13 +306,21 @@ def _lower_tpu_kernel(
except ir.MLIRError as e:
raise ValueError("The compiled module fails MLIR verification") from e
timestamp = time.time_ns()
dump_cnt = [0]
def get_dump_file_prefix() -> str:
s = f"{timestamp}-{dump_cnt[0]:04}"
dump_cnt[0] += 1
return s
with module.context as ctx, module.operation.location as _:
ctx.append_dialect_registry(mlir.upstream_dialects)
ctx.load_all_available_dialects()
tpu.register_dialect(ctx)
mhlo.register_mhlo_dialect(ctx)
mhlo.register_mhlo_passes()
dump_mlir(module, "original", kernel_name)
dump_mlir(module, "original", get_dump_file_prefix(), kernel_name)
if _MOSAIC_ALLOW_HLO.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
@ -323,7 +331,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-hlo-conversion")
dump_mlir(module, "post-hlo-conversion", get_dump_file_prefix(), kernel_name)
sl_cnt, l_cnt = target_shape
# Note: we don't pass the TpuTilingFlags here, since we don't know the
@ -338,7 +346,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-memref-layout")
dump_mlir(module, "post-infer-memref-layout", get_dump_file_prefix(), kernel_name)
pipeline = [
"canonicalize",
@ -346,7 +354,12 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-memref-layout-simplify")
dump_mlir(
module,
"post-infer-memref-layout-simplify",
get_dump_file_prefix(),
kernel_name,
)
try:
on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value
@ -360,7 +373,7 @@ def _lower_tpu_kernel(
"builtin.module(func.func(debug-assert-insertion))"
)
pipeline.run(module.operation)
dump_mlir(module, "post-assert-insertion")
dump_mlir(module, "post-assert-insertion", get_dump_file_prefix(), kernel_name)
elif checks:
checks.discard("bounds")
raise ValueError(
@ -376,7 +389,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-canonicalize-mosaic")
dump_mlir(module, "post-canonicalize-mosaic", get_dump_file_prefix(), kernel_name)
pipeline = [
(
@ -388,7 +401,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-vector-layout")
dump_mlir(module, "post-infer-vector-layout", get_dump_file_prefix(), kernel_name)
pipeline = [
(
@ -399,7 +412,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-relayout-insertion")
dump_mlir(module, "post-relayout-insertion", get_dump_file_prefix(), kernel_name)
mxu_size = 128 if hardware_generation < 6 else 256
pipeline = [
@ -412,7 +425,7 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-apply-vector-layout")
dump_mlir(module, "post-apply-vector-layout", get_dump_file_prefix(), kernel_name)
pipeline = [
"canonicalize",
@ -420,7 +433,12 @@ def _lower_tpu_kernel(
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-apply-vector-layout-simplify")
dump_mlir(
module,
"post-apply-vector-layout-simplify",
get_dump_file_prefix(),
kernel_name,
)
return module
@ -777,7 +795,9 @@ def _as_jax_callable(
return jax.jit(apply_kernel)
def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
def dump_mlir(
module: ir.Module, name: str, prefix: str, kernel_name: str | None = None
):
"""A helper function to dump mosaic mlir module"""
try:
should_dump = FLAGS["xla_mosaic_dump_to"].value
@ -788,6 +808,6 @@ def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
if outdir:
if kernel_name:
name = f"{kernel_name}-{name}"
path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt")
path = os.path.join(outdir, f"{prefix}-mosaic-dump-{name}-py.txt")
with open(path, "w") as f:
f.write(str(module))