Add spaces before dump instructions

Otherwise, the error message comes out weird and the instructions may be missed:

```
ValueError: Cannot lower jaxpr with verifier errors:
	'stablehlo.add' op requires compatible types for all operands and results
		at loc("jit(add)/jit(main)/add"("<module>"("/tmp/jax_exploration/main.py":7:6)))
	see current operation: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> tensor<i32>
		at loc("jit(add)/jit(main)/add"("<module>"("/tmp/jax_exploration/main.py":7:6)))Define JAX_DUMP_IR_TO to dump the module.
```

See also this error message in the docs: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#numpy-lax-xla-jax-api-layering (scroll horizontally to see the "Define JAX_DUMP_IR_TO ..." all the way to the right)

PiperOrigin-RevId: 639882828
This commit is contained in:
Michael Levesque-Dion 2024-06-03 12:56:28 -07:00 committed by jax authors
parent d99cd3e99d
commit 2bfafe35d2

View File

@ -970,7 +970,7 @@ def lower_jaxpr_to_module(
try:
if not ctx.module.operation.verify():
raise ValueError(
"Cannot lower jaxpr with verifier errors." +
"Cannot lower jaxpr with verifier errors. " +
dump_module_message(ctx.module, "verification"))
except ir.MLIRError as e:
msg_lines = ["Cannot lower jaxpr with verifier errors:"]
@ -981,7 +981,7 @@ def lower_jaxpr_to_module(
emit_diagnostic_info(n)
for d in e.error_diagnostics:
emit_diagnostic_info(d)
raise ValueError("\n".join(msg_lines) +
raise ValueError("\n".join(msg_lines) + "\n" +
dump_module_message(ctx.module, "verification")) from e
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,