Merge pull request #25797 from gnecula:print_debug_info

PiperOrigin-RevId: 713958983
This commit is contained in:
jax authors 2025-01-10 01:39:34 -08:00
commit 228d3cef0b
9 changed files with 59 additions and 37 deletions

View File

@ -32,7 +32,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{jax-issue}`#25606` for more details.
* Support added for user defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.
* The AOT lowering `.as_text()` method now supports the `debug_info` option
to include debugging information, e.g., source location, in the output.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name

View File

@ -218,6 +218,8 @@ a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).
You can obtain more debbugging information, e.g., source location,
by using the `debug_info` parameter to `lowered.as_text()`.
These methods are meant as an aid for manual inspection and debugging, not as a
reliably programmable API. Their availability and output vary by compiler,

View File

@ -162,7 +162,8 @@ class Lowering(Protocol):
"""Compile and return a corresponding ``Executable``."""
raise NotImplementedError
def as_text(self, dialect: str | None = None) -> str:
def as_text(self, dialect: str | None = None, *,
debug_info: bool = False) -> str:
"""A human-readable text representation of this lowering.
Intended for visualization and debugging purposes. This need not be a valid
@ -340,13 +341,18 @@ class XlaLowering(Lowering):
self, compiler_options: CompilerOptions | None = None) -> Executable:
raise NotImplementedError("must override")
def as_text(self, dialect: str | None = None) -> str:
def as_text(self, dialect: str | None = None,
*,
debug_info: bool = False) -> str:
if dialect is None:
dialect = "stablehlo"
if dialect == "stablehlo":
return str(self.stablehlo())
return mlir.module_to_string(self.stablehlo(),
enable_debug_info=debug_info)
elif dialect == "hlo":
return self.hlo().as_hlo_text()
print_opts = xc._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = debug_info
return self.hlo().as_hlo_module().to_string(print_opts)
else:
raise ValueError(f"unknown dialect: {dialect}")
@ -675,16 +681,21 @@ class Lowered(Stage):
no_kwargs=self._no_kwargs,
)
def as_text(self, dialect: str | None = None) -> str:
def as_text(self, dialect: str | None = None, *,
debug_info: bool = False) -> str:
"""A human-readable text representation of this lowering.
Intended for visualization and debugging purposes. This need not be a valid
nor reliable serialization. It is relayed directly to external callers.
nor reliable serialization.
Use `jax.export` if you want reliable and portable serialization.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo",
or "hlo").
debug_info: Whether to include debugging information,
e.g., source location.
"""
return self._lowering.as_text(dialect)
return self._lowering.as_text(dialect, debug_info=debug_info)
def compiler_ir(self, dialect: str | None = None) -> Any | None:
"""An arbitrary object representation of this lowering.
@ -692,12 +703,14 @@ class Lowered(Stage):
Intended for debugging purposes. This is not a valid nor reliable
serialization. The output has no guarantee of consistency across
invocations.
Use `jax.export` if you want reliable and portable serialization.
Returns ``None`` if unavailable, e.g. based on backend, compiler, or
runtime.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo",
or "hlo").
"""
try:
return self._lowering.compiler_ir(dialect)

View File

@ -112,6 +112,21 @@ class JaxAotTest(jtu.JaxTestCase):
topo.platform_version, aot_topo.devices[0].client.platform_version
)
def test_lower_as_text_with_and_without_debug_info(self):
def my_function(x):
return jnp.sin(x)
lowered = jax.jit(my_function).lower(42.)
stablehlo = lowered.as_text("stablehlo", debug_info=True)
self.assertRegex(stablehlo, r"sine.* loc")
stablehlo = lowered.as_text("stablehlo")
self.assertNotRegex(stablehlo, r"sine.* loc")
hlo = lowered.as_text("hlo", debug_info=True)
self.assertRegex(hlo, r"sine.*metadata=.*source_file=.*")
hlo = lowered.as_text("hlo")
self.assertNotRegex(hlo, r"sine.*metadata=.*source_file=.*")
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1295,7 +1295,7 @@ class JitTest(jtu.BufferDonationTestCase):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
@ -1303,10 +1303,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = mlir.module_to_string(
lowered.compiler_ir('stablehlo'),
enable_debug_info=False,
)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@ -1315,9 +1312,10 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
ir = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
lowered = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
@ -1325,7 +1323,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@ -1334,9 +1332,9 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
ir = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
lowered = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
@ -1346,7 +1344,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertNotIn("kwargs['a']", hlo_str)
self.assertNotIn("kwargs['b']", hlo_str)
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in (
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
"kwargs['w']", "kwargs['a']", "kwargs['b']"
@ -1357,8 +1355,7 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x, y, z):
return {'a': x, 'b': [y]}
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)

View File

@ -21,17 +21,14 @@ from jax import lax
from jax._src.pjit import pjit
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src import ad_checkpoint
jax.config.parse_flags_with_absl()
def _get_hlo(f):
def wrapped(*args, **kwargs):
c = jax.jit(f).lower(*args, **kwargs).compiler_ir('hlo')
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
return c.as_hlo_module().to_string(print_opts)
return jax.jit(f).lower(*args, **kwargs).as_text('hlo', debug_info=True)
return wrapped

View File

@ -29,7 +29,6 @@ from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
from jax import nn
@ -48,7 +47,7 @@ def _is_required_cudnn_version_satisfied(min_cudnn_version):
def _check_cudnn_backend(fn, *args, **kwargs):
lowered = jax.jit(fn).lower(*args, **kwargs)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
hlo = lowered.as_text('stablehlo', debug_info=True)
return '__cudnn$fmha' in hlo
_cudnn_dbias_error = 'cuDNN only supports bias gradient'

View File

@ -46,7 +46,6 @@ from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lax import parallel
from jax._src.lib import xla_extension
@ -2109,7 +2108,7 @@ class PythonPmapTest(jtu.JaxTestCase):
lowered = jax.pmap(f).lower(
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertIn("args[0]", hlo_str)
@ -2123,7 +2122,7 @@ class PythonPmapTest(jtu.JaxTestCase):
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
[jnp.array([3])])
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)

View File

@ -40,7 +40,6 @@ from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AbstractMesh
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
@ -1312,7 +1311,7 @@ class ShardMapTest(jtu.JaxTestCase):
in_specs=P('i'), out_specs=P('i'))(x)
return x
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
hlo_str = jax.jit(foo).lower(x).as_text("stablehlo", debug_info=True)
if config.use_shardy_partitioner.value:
if len(jax.devices()) > 1:
self.assertEqual(2, hlo_str.count('sdy.manual_computation'))