mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #25797 from gnecula:print_debug_info
PiperOrigin-RevId: 713958983
This commit is contained in:
commit
228d3cef0b
@ -32,7 +32,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
{jax-issue}`#25606` for more details.
|
||||
* Support added for user defined state in the FFI via the new
|
||||
{func}`jax.ffi.register_ffi_type_id` function.
|
||||
|
||||
* The AOT lowering `.as_text()` method now supports the `debug_info` option
|
||||
to include debugging information, e.g., source location, in the output.
|
||||
* Deprecations
|
||||
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
|
||||
are now deprecated, having been replaced by symbols of the same name
|
||||
|
@ -218,6 +218,8 @@ a text representation. Compiled functions do the same, and also offer cost and
|
||||
memory analyses from the compiler. All of these are provided via methods on the
|
||||
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
|
||||
`lowered.as_text()` and `compiled.cost_analysis()` above).
|
||||
You can obtain more debbugging information, e.g., source location,
|
||||
by using the `debug_info` parameter to `lowered.as_text()`.
|
||||
|
||||
These methods are meant as an aid for manual inspection and debugging, not as a
|
||||
reliably programmable API. Their availability and output vary by compiler,
|
||||
|
@ -162,7 +162,8 @@ class Lowering(Protocol):
|
||||
"""Compile and return a corresponding ``Executable``."""
|
||||
raise NotImplementedError
|
||||
|
||||
def as_text(self, dialect: str | None = None) -> str:
|
||||
def as_text(self, dialect: str | None = None, *,
|
||||
debug_info: bool = False) -> str:
|
||||
"""A human-readable text representation of this lowering.
|
||||
|
||||
Intended for visualization and debugging purposes. This need not be a valid
|
||||
@ -340,13 +341,18 @@ class XlaLowering(Lowering):
|
||||
self, compiler_options: CompilerOptions | None = None) -> Executable:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def as_text(self, dialect: str | None = None) -> str:
|
||||
def as_text(self, dialect: str | None = None,
|
||||
*,
|
||||
debug_info: bool = False) -> str:
|
||||
if dialect is None:
|
||||
dialect = "stablehlo"
|
||||
if dialect == "stablehlo":
|
||||
return str(self.stablehlo())
|
||||
return mlir.module_to_string(self.stablehlo(),
|
||||
enable_debug_info=debug_info)
|
||||
elif dialect == "hlo":
|
||||
return self.hlo().as_hlo_text()
|
||||
print_opts = xc._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = debug_info
|
||||
return self.hlo().as_hlo_module().to_string(print_opts)
|
||||
else:
|
||||
raise ValueError(f"unknown dialect: {dialect}")
|
||||
|
||||
@ -675,16 +681,21 @@ class Lowered(Stage):
|
||||
no_kwargs=self._no_kwargs,
|
||||
)
|
||||
|
||||
def as_text(self, dialect: str | None = None) -> str:
|
||||
def as_text(self, dialect: str | None = None, *,
|
||||
debug_info: bool = False) -> str:
|
||||
"""A human-readable text representation of this lowering.
|
||||
|
||||
Intended for visualization and debugging purposes. This need not be a valid
|
||||
nor reliable serialization. It is relayed directly to external callers.
|
||||
nor reliable serialization.
|
||||
Use `jax.export` if you want reliable and portable serialization.
|
||||
|
||||
Args:
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo",
|
||||
or "hlo").
|
||||
debug_info: Whether to include debugging information,
|
||||
e.g., source location.
|
||||
"""
|
||||
return self._lowering.as_text(dialect)
|
||||
return self._lowering.as_text(dialect, debug_info=debug_info)
|
||||
|
||||
def compiler_ir(self, dialect: str | None = None) -> Any | None:
|
||||
"""An arbitrary object representation of this lowering.
|
||||
@ -692,12 +703,14 @@ class Lowered(Stage):
|
||||
Intended for debugging purposes. This is not a valid nor reliable
|
||||
serialization. The output has no guarantee of consistency across
|
||||
invocations.
|
||||
Use `jax.export` if you want reliable and portable serialization.
|
||||
|
||||
Returns ``None`` if unavailable, e.g. based on backend, compiler, or
|
||||
runtime.
|
||||
|
||||
Args:
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo",
|
||||
or "hlo").
|
||||
"""
|
||||
try:
|
||||
return self._lowering.compiler_ir(dialect)
|
||||
|
@ -112,6 +112,21 @@ class JaxAotTest(jtu.JaxTestCase):
|
||||
topo.platform_version, aot_topo.devices[0].client.platform_version
|
||||
)
|
||||
|
||||
def test_lower_as_text_with_and_without_debug_info(self):
|
||||
def my_function(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
lowered = jax.jit(my_function).lower(42.)
|
||||
stablehlo = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertRegex(stablehlo, r"sine.* loc")
|
||||
stablehlo = lowered.as_text("stablehlo")
|
||||
self.assertNotRegex(stablehlo, r"sine.* loc")
|
||||
|
||||
hlo = lowered.as_text("hlo", debug_info=True)
|
||||
self.assertRegex(hlo, r"sine.*metadata=.*source_file=.*")
|
||||
hlo = lowered.as_text("hlo")
|
||||
self.assertNotRegex(hlo, r"sine.*metadata=.*source_file=.*")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1295,7 +1295,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
return y['hi'] + args[1] + sum(kwargs.values())
|
||||
|
||||
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
@ -1303,10 +1303,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
|
||||
hlo_str = mlir.module_to_string(
|
||||
lowered.compiler_ir('stablehlo'),
|
||||
enable_debug_info=False,
|
||||
)
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=False)
|
||||
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
|
||||
self.assertNotIn(s, hlo_str)
|
||||
|
||||
@ -1315,9 +1312,10 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
def f(x, y, *args, **kwargs):
|
||||
return y['hi'] + args[1] + sum(kwargs.values())
|
||||
|
||||
ir = jax.jit(f, static_argnums=static_argnums).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
lowered = jax.jit(f, static_argnums=static_argnums).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.)
|
||||
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
@ -1325,7 +1323,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
|
||||
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=False)
|
||||
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
|
||||
self.assertNotIn(s, hlo_str)
|
||||
|
||||
@ -1334,9 +1332,9 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
def f(x, y, *args, **kwargs):
|
||||
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
|
||||
|
||||
ir = jax.jit(f, static_argnames=static_argnames).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
lowered = jax.jit(f, static_argnames=static_argnames).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.)
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
@ -1346,7 +1344,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertNotIn("kwargs['a']", hlo_str)
|
||||
self.assertNotIn("kwargs['b']", hlo_str)
|
||||
|
||||
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=False)
|
||||
for s in (
|
||||
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
|
||||
"kwargs['w']", "kwargs['a']", "kwargs['b']"
|
||||
@ -1357,8 +1355,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
def f(x, y, z):
|
||||
return {'a': x, 'b': [y]}
|
||||
|
||||
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True)
|
||||
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||
|
||||
|
@ -21,17 +21,14 @@ from jax import lax
|
||||
from jax._src.pjit import pjit
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import ad_checkpoint
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def _get_hlo(f):
|
||||
def wrapped(*args, **kwargs):
|
||||
c = jax.jit(f).lower(*args, **kwargs).compiler_ir('hlo')
|
||||
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = True
|
||||
return c.as_hlo_module().to_string(print_opts)
|
||||
return jax.jit(f).lower(*args, **kwargs).as_text('hlo', debug_info=True)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
|
@ -29,7 +29,6 @@ from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax.test_util import check_grads
|
||||
from jax import nn
|
||||
@ -48,7 +47,7 @@ def _is_required_cudnn_version_satisfied(min_cudnn_version):
|
||||
|
||||
def _check_cudnn_backend(fn, *args, **kwargs):
|
||||
lowered = jax.jit(fn).lower(*args, **kwargs)
|
||||
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
hlo = lowered.as_text('stablehlo', debug_info=True)
|
||||
return '__cudnn$fmha' in hlo
|
||||
|
||||
_cudnn_dbias_error = 'cuDNN only supports bias gradient'
|
||||
|
@ -46,7 +46,6 @@ from jax._src import sharding_impls
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lax import parallel
|
||||
from jax._src.lib import xla_extension
|
||||
@ -2109,7 +2108,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
lowered = jax.pmap(f).lower(
|
||||
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
|
||||
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertIn("args[0]", hlo_str)
|
||||
@ -2123,7 +2122,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
|
||||
[jnp.array([3])])
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
hlo_str = lowered.as_text("stablehlo", debug_info=True)
|
||||
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||
|
||||
|
@ -40,7 +40,6 @@ from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.mesh import AbstractMesh
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
@ -1312,7 +1311,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
in_specs=P('i'), out_specs=P('i'))(x)
|
||||
return x
|
||||
|
||||
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
|
||||
hlo_str = jax.jit(foo).lower(x).as_text("stablehlo", debug_info=True)
|
||||
if config.use_shardy_partitioner.value:
|
||||
if len(jax.devices()) > 1:
|
||||
self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user