mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Append dump id to timestamp to make dump list ordered
PiperOrigin-RevId: 715488504
This commit is contained in:
parent
f122f17b27
commit
6851700ed4
@ -306,13 +306,21 @@ def _lower_tpu_kernel(
|
||||
except ir.MLIRError as e:
|
||||
raise ValueError("The compiled module fails MLIR verification") from e
|
||||
|
||||
timestamp = time.time_ns()
|
||||
dump_cnt = [0]
|
||||
|
||||
def get_dump_file_prefix() -> str:
|
||||
s = f"{timestamp}-{dump_cnt[0]:04}"
|
||||
dump_cnt[0] += 1
|
||||
return s
|
||||
|
||||
with module.context as ctx, module.operation.location as _:
|
||||
ctx.append_dialect_registry(mlir.upstream_dialects)
|
||||
ctx.load_all_available_dialects()
|
||||
tpu.register_dialect(ctx)
|
||||
mhlo.register_mhlo_dialect(ctx)
|
||||
mhlo.register_mhlo_passes()
|
||||
dump_mlir(module, "original", kernel_name)
|
||||
dump_mlir(module, "original", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
if _MOSAIC_ALLOW_HLO.value:
|
||||
# Run hlo dialect conversion: hlo -> linalg -> vector.
|
||||
@ -323,7 +331,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-hlo-conversion")
|
||||
dump_mlir(module, "post-hlo-conversion", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
sl_cnt, l_cnt = target_shape
|
||||
# Note: we don't pass the TpuTilingFlags here, since we don't know the
|
||||
@ -338,7 +346,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-infer-memref-layout")
|
||||
dump_mlir(module, "post-infer-memref-layout", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
pipeline = [
|
||||
"canonicalize",
|
||||
@ -346,7 +354,12 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-infer-memref-layout-simplify")
|
||||
dump_mlir(
|
||||
module,
|
||||
"post-infer-memref-layout-simplify",
|
||||
get_dump_file_prefix(),
|
||||
kernel_name,
|
||||
)
|
||||
|
||||
try:
|
||||
on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value
|
||||
@ -360,7 +373,7 @@ def _lower_tpu_kernel(
|
||||
"builtin.module(func.func(debug-assert-insertion))"
|
||||
)
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-assert-insertion")
|
||||
dump_mlir(module, "post-assert-insertion", get_dump_file_prefix(), kernel_name)
|
||||
elif checks:
|
||||
checks.discard("bounds")
|
||||
raise ValueError(
|
||||
@ -376,7 +389,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-canonicalize-mosaic")
|
||||
dump_mlir(module, "post-canonicalize-mosaic", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
pipeline = [
|
||||
(
|
||||
@ -388,7 +401,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-infer-vector-layout")
|
||||
dump_mlir(module, "post-infer-vector-layout", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
pipeline = [
|
||||
(
|
||||
@ -399,7 +412,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-relayout-insertion")
|
||||
dump_mlir(module, "post-relayout-insertion", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
mxu_size = 128 if hardware_generation < 6 else 256
|
||||
pipeline = [
|
||||
@ -412,7 +425,7 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-apply-vector-layout")
|
||||
dump_mlir(module, "post-apply-vector-layout", get_dump_file_prefix(), kernel_name)
|
||||
|
||||
pipeline = [
|
||||
"canonicalize",
|
||||
@ -420,7 +433,12 @@ def _lower_tpu_kernel(
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-apply-vector-layout-simplify")
|
||||
dump_mlir(
|
||||
module,
|
||||
"post-apply-vector-layout-simplify",
|
||||
get_dump_file_prefix(),
|
||||
kernel_name,
|
||||
)
|
||||
|
||||
return module
|
||||
|
||||
@ -777,7 +795,9 @@ def _as_jax_callable(
|
||||
return jax.jit(apply_kernel)
|
||||
|
||||
|
||||
def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
|
||||
def dump_mlir(
|
||||
module: ir.Module, name: str, prefix: str, kernel_name: str | None = None
|
||||
):
|
||||
"""A helper function to dump mosaic mlir module"""
|
||||
try:
|
||||
should_dump = FLAGS["xla_mosaic_dump_to"].value
|
||||
@ -788,6 +808,6 @@ def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
|
||||
if outdir:
|
||||
if kernel_name:
|
||||
name = f"{kernel_name}-{name}"
|
||||
path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt")
|
||||
path = os.path.join(outdir, f"{prefix}-mosaic-dump-{name}-py.txt")
|
||||
with open(path, "w") as f:
|
||||
f.write(str(module))
|
||||
|
Loading…
x
Reference in New Issue
Block a user