mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add spaces before dump instructions
Otherwise, the error message comes out weird and the instructions may be missed: ``` ValueError: Cannot lower jaxpr with verifier errors: 'stablehlo.add' op requires compatible types for all operands and results at loc("jit(add)/jit(main)/add"("<module>"("/tmp/jax_exploration/main.py":7:6))) see current operation: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> tensor<i32> at loc("jit(add)/jit(main)/add"("<module>"("/tmp/jax_exploration/main.py":7:6)))Define JAX_DUMP_IR_TO to dump the module. ``` See also this error message in the docs: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#numpy-lax-xla-jax-api-layering (scroll horizontally to see the "Define JAX_DUMP_IR_TO ..." all the way to the right) PiperOrigin-RevId: 639882828
This commit is contained in:
parent
d99cd3e99d
commit
2bfafe35d2
@ -970,7 +970,7 @@ def lower_jaxpr_to_module(
|
||||
try:
|
||||
if not ctx.module.operation.verify():
|
||||
raise ValueError(
|
||||
"Cannot lower jaxpr with verifier errors." +
|
||||
"Cannot lower jaxpr with verifier errors. " +
|
||||
dump_module_message(ctx.module, "verification"))
|
||||
except ir.MLIRError as e:
|
||||
msg_lines = ["Cannot lower jaxpr with verifier errors:"]
|
||||
@ -981,7 +981,7 @@ def lower_jaxpr_to_module(
|
||||
emit_diagnostic_info(n)
|
||||
for d in e.error_diagnostics:
|
||||
emit_diagnostic_info(d)
|
||||
raise ValueError("\n".join(msg_lines) +
|
||||
raise ValueError("\n".join(msg_lines) + "\n" +
|
||||
dump_module_message(ctx.module, "verification")) from e
|
||||
|
||||
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
|
||||
|
Loading…
x
Reference in New Issue
Block a user