mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Compute source maps when pretty-printing jaxprs.
This change prepares for emitting source map information (in the format described at https://tc39.es/source-map/) for jaxprs, so that the relationship between a jaxpr and its Python code can be visualized using tooling built for that purpose. It adds a new `source_map()` pretty-printer document, which causes the pretty-printer to populate a `source_map` side output during pretty printing. It also teaches the core jaxpr pretty-printer to populate source map information on each equation.
This commit is contained in:
parent
7681493760
commit
d014f5dc5f
104
jax/_src/core.py
104
jax/_src/core.py
@ -149,51 +149,11 @@ class Jaxpr:
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True,
|
||||
custom_pp_eqn_rules=True, name_stack=False,
|
||||
print_effects: bool = False, **kwargs):
|
||||
context = JaxprPpContext()
|
||||
settings = JaxprPpSettings(
|
||||
source_info=source_info,
|
||||
print_shapes=print_shapes,
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules,
|
||||
name_stack=name_stack,
|
||||
print_effects=print_effects)
|
||||
|
||||
# Compute how many times each jaxpr is used.
|
||||
names = defaultdict[Jaxpr, str](lambda: "jaxpr")
|
||||
jaxpr_counts = Counter[Jaxpr]()
|
||||
s = deque([self])
|
||||
while s:
|
||||
jaxpr = s.popleft()
|
||||
jaxpr_counts[jaxpr] += 1
|
||||
for eqn in jaxpr.eqns:
|
||||
# TODO(slebedev): Come up with a more elaborate heuristic for name=.
|
||||
name = eqn.params.get("name")
|
||||
if name is None:
|
||||
s.extend(jaxprs_in_params(eqn.params))
|
||||
continue
|
||||
name = name.strip("<>") # <lambda> -> lambda
|
||||
for subjaxpr in jaxprs_in_params(eqn.params):
|
||||
s.append(subjaxpr)
|
||||
names.setdefault(subjaxpr, name)
|
||||
|
||||
# Pull jaxprs occurring more than once to the top-level, making sure
|
||||
# that their names are unique.
|
||||
docs = []
|
||||
name_counts = Counter[str]()
|
||||
for jaxpr, c in jaxpr_counts.items():
|
||||
if c == 1:
|
||||
continue
|
||||
name = names[jaxpr]
|
||||
if (count := name_counts[name]) > 0:
|
||||
name_counts[name] += 1
|
||||
name += str(count)
|
||||
name_counts[name] += 1
|
||||
else:
|
||||
name_counts[name] += 1
|
||||
docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
|
||||
context.used_names.add(name)
|
||||
context.top_level_jaxprs[jaxpr] = name
|
||||
docs.append(pp_jaxpr(self, context, settings))
|
||||
return pp.concat(docs).format(**kwargs)
|
||||
doc = pp_toplevel_jaxpr(
|
||||
self, source_info=source_info, print_shapes=print_shapes,
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack,
|
||||
print_effects=print_effects)
|
||||
return doc.format(**kwargs)
|
||||
|
||||
def _repr_pretty_(self, p, cycle):
|
||||
return p.text(self.pretty_print(use_color=True))
|
||||
@ -212,6 +172,7 @@ class Jaxpr:
|
||||
return jaxpr
|
||||
|
||||
|
||||
|
||||
def join_effects(*effects: Effects) -> Effects:
  """Returns the union of all the given effect collections.

  When called with no arguments, the module-level ``no_effects`` value is
  returned instead of a freshly built empty set.
  """
  if not effects:
    return no_effects
  return set().union(*effects)
|
||||
|
||||
@ -3164,6 +3125,55 @@ def _check_map(ctx_factory, prim, in_avals, params):
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
||||
def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
                      custom_pp_eqn_rules=True, name_stack=False,
                      print_effects: bool = False) -> pp.Doc:
  """Builds the pretty-printer document for ``jaxpr_to_print``.

  Jaxprs that are reached more than once while walking the equations'
  params are emitted first as named top-level definitions; the document for
  ``jaxpr_to_print`` itself is appended last.

  Args:
    jaxpr_to_print: the root jaxpr to render.
    source_info, print_shapes, custom_pp_eqn_rules, name_stack,
      print_effects: forwarded unchanged into ``JaxprPpSettings``.

  Returns:
    The concatenation of all generated documents. The document is returned
    unformatted; no ``format`` call is made here.
  """
  settings = JaxprPpSettings(
      source_info=source_info,
      print_shapes=print_shapes,
      custom_pp_eqn_rules=custom_pp_eqn_rules,
      name_stack=name_stack,
      print_effects=print_effects)
  context = JaxprPpContext()

  # Breadth-first walk that counts how many times each jaxpr is reached and
  # remembers the first name seen for it ("jaxpr" when none was recorded).
  names = defaultdict[Jaxpr, str](lambda: "jaxpr")
  jaxpr_counts = Counter[Jaxpr]()
  pending = deque([jaxpr_to_print])
  while pending:
    current = pending.popleft()
    jaxpr_counts[current] += 1
    for eqn in current.eqns:
      # TODO(slebedev): Come up with a more elaborate heuristic for name=.
      raw_name = eqn.params.get("name")
      if raw_name is not None:
        stripped = raw_name.strip("<>")  # <lambda> -> lambda
        for subjaxpr in jaxprs_in_params(eqn.params):
          pending.append(subjaxpr)
          names.setdefault(subjaxpr, stripped)
      else:
        pending.extend(jaxprs_in_params(eqn.params))

  # Hoist every jaxpr seen more than once to the top level, appending a
  # numeric suffix whenever the chosen name has already been handed out.
  docs = []
  name_counts = Counter[str]()
  for shared, uses in jaxpr_counts.items():
    if uses == 1:
      continue
    label = names[shared]
    already_used = name_counts[label]
    name_counts[label] += 1
    if already_used > 0:
      label += str(already_used)
      name_counts[label] += 1
    docs.append(pp_top_level_jaxpr(label, shared, context, settings))
    context.used_names.add(label)
    context.top_level_jaxprs[shared] = label

  docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
  return pp.concat(docs)
|
||||
|
||||
|
||||
class JaxprPpSettings(NamedTuple):
|
||||
print_shapes: bool = True
|
||||
@ -3253,7 +3263,9 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
|
||||
) -> pp.Doc:
|
||||
rule = (_pp_eqn if not settings.custom_pp_eqn_rules else
|
||||
pp_eqn_rules.get(eqn.primitive, _pp_eqn))
|
||||
return rule(eqn, context, settings) # type: ignore[operator]
|
||||
doc = rule(eqn, context, settings) # type: ignore[operator]
|
||||
user_frame = source_info_util.user_frame(eqn.source_info)
|
||||
return doc if user_frame is None else pp.source_map(doc, user_frame)
|
||||
|
||||
def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
|
@ -31,7 +31,7 @@ from collections.abc import Sequence
|
||||
import enum
|
||||
from functools import partial
|
||||
import sys
|
||||
from typing import NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import util
|
||||
@ -69,12 +69,23 @@ CAN_USE_COLOR = _can_use_color()
|
||||
class Doc(util.StrictABC):
|
||||
__slots__ = ()
|
||||
|
||||
def format(self, width: int = 80, use_color: bool | None = None,
|
||||
annotation_prefix=" # ") -> str:
|
||||
def format(
|
||||
self, width: int = 80, *, use_color: bool | None = None,
|
||||
annotation_prefix: str = " # ",
|
||||
source_map: list[list[tuple[int, int, Any]]] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Formats a pretty-printer document as a string.
|
||||
|
||||
Args:
|
||||
source_map: for each line in the output, contains a list of
|
||||
(start column, end column, source) tuples. Each tuple associates a
|
||||
region of output text with a source.
|
||||
"""
|
||||
if use_color is None:
|
||||
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
|
||||
return _format(self, width, use_color=use_color,
|
||||
annotation_prefix=annotation_prefix)
|
||||
annotation_prefix=annotation_prefix, source_map=source_map)
|
||||
|
||||
def __str__(self):
|
||||
return self.format()
|
||||
@ -147,6 +158,21 @@ class _NestDoc(Doc):
|
||||
def __repr__(self): return f"nest({self.n, self.child})"
|
||||
|
||||
|
||||
_NO_SOURCE = object()
|
||||
|
||||
class _SourceMapDoc(Doc):
  """Document node that associates its child document with a ``source``.

  The node holds no text of its own; it only wraps ``child`` and stores the
  ``source`` object alongside it.
  """
  __slots__ = ("child", "source")
  child: Doc    # The wrapped document.
  source: Any   # Arbitrary object identifying where ``child`` came from.

  def __init__(self, child: Doc, source: Any):
    assert isinstance(child, Doc), child
    self.child = child
    self.source = source

  def __repr__(self):
    # Use repr() of both fields. Interpolating the child with str() would go
    # through Doc.__str__, which runs a full format() of the sub-document
    # rather than showing its structure.
    return f"source({self.child!r}, {self.source!r})"
|
||||
|
||||
|
||||
Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE",
|
||||
"MAGENTA", "CYAN", "WHITE", "RESET"])
|
||||
Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"])
|
||||
@ -193,7 +219,7 @@ def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]]
|
||||
agenda.append((i + doc.n, m, doc.child))
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
agenda.append((i, _BreakMode.FLAT, doc.child))
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
|
||||
agenda.append((i, m, doc.child))
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
@ -224,7 +250,7 @@ def _sparse(doc: Doc) -> bool:
|
||||
agenda.append(doc.child)
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
agenda.append(doc.child)
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
|
||||
agenda.append(doc.child)
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
@ -241,6 +267,7 @@ class _State(NamedTuple):
|
||||
mode: _BreakMode
|
||||
doc: Doc
|
||||
color: _ColorState
|
||||
source_map: Any
|
||||
|
||||
class _Line(NamedTuple):
|
||||
text: str
|
||||
@ -283,17 +310,29 @@ def _align_annotations(lines):
|
||||
|
||||
|
||||
|
||||
def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
def _format(
|
||||
doc: Doc, width: int, *, use_color: bool, annotation_prefix: str,
|
||||
source_map: list[list[tuple[int, int, Any]]] | None
|
||||
) -> str:
|
||||
lines = []
|
||||
default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
|
||||
annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)
|
||||
color_state = default_colors
|
||||
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)]
|
||||
source_start = 0 # The column at which the current source region starts.
|
||||
source = _NO_SOURCE # The currently active source region.
|
||||
line_source_map = [] # Source maps for the current line of text.
|
||||
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)]
|
||||
k = 0
|
||||
line_text = ""
|
||||
line_annotations = []
|
||||
while len(agenda) > 0:
|
||||
i, m, doc, color = agenda.pop()
|
||||
i, m, doc, color, agenda_source = agenda.pop()
|
||||
if source_map is not None and agenda_source != source:
|
||||
pos = len(line_text)
|
||||
if source_start != pos and source is not _NO_SOURCE:
|
||||
line_source_map.append((source_start, pos, source))
|
||||
source = agenda_source
|
||||
source_start = pos
|
||||
if isinstance(doc, _NilDoc):
|
||||
pass
|
||||
elif isinstance(doc, _TextDoc):
|
||||
@ -304,7 +343,7 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
line_annotations.append(doc.annotation)
|
||||
k += len(doc.text)
|
||||
elif isinstance(doc, _ConcatDoc):
|
||||
agenda.extend(_State(i, m, d, color)
|
||||
agenda.extend(_State(i, m, d, color, source)
|
||||
for d in reversed(doc.children))
|
||||
elif isinstance(doc, _BreakDoc):
|
||||
if m == _BreakMode.BREAK:
|
||||
@ -313,6 +352,13 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
annotation_colors)
|
||||
line_text += color_str
|
||||
lines.append(_Line(line_text, k, line_annotations))
|
||||
if source_map is not None:
|
||||
pos = len(line_text)
|
||||
if source_start != pos and source is not _NO_SOURCE:
|
||||
line_source_map.append((source_start, pos, source))
|
||||
source_map.append(line_source_map)
|
||||
line_source_map = []
|
||||
source_start = i
|
||||
line_text = " " * i
|
||||
line_annotations = []
|
||||
k = i
|
||||
@ -322,20 +368,22 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
line_text += doc.text
|
||||
k += len(doc.text)
|
||||
elif isinstance(doc, _NestDoc):
|
||||
agenda.append(_State(i + doc.n, m, doc.child, color))
|
||||
agenda.append(_State(i + doc.n, m, doc.child, color, source))
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
# In Lindig's paper, _fits is passed the remainder of the document.
|
||||
# I'm pretty sure that's a bug and we care only if the current group fits!
|
||||
if (_sparse(doc)
|
||||
and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])):
|
||||
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color))
|
||||
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source))
|
||||
else:
|
||||
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color))
|
||||
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source))
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
color = _ColorState(doc.foreground or color.foreground,
|
||||
doc.background or color.background,
|
||||
doc.intensity or color.intensity)
|
||||
agenda.append(_State(i, m, doc.child, color))
|
||||
agenda.append(_State(i, m, doc.child, color, source))
|
||||
elif isinstance(doc, _SourceMapDoc):
|
||||
agenda.append(_State(i, m, doc.child, color, doc.source))
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
|
||||
@ -343,6 +391,11 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
color_state, color_str = _update_color(use_color, color_state,
|
||||
annotation_colors)
|
||||
line_text += color_str
|
||||
if source_map is not None:
|
||||
pos = len(line_text)
|
||||
if source_start != pos and source is not _NO_SOURCE:
|
||||
line_source_map.append((source_start, pos, source))
|
||||
source_map.append(line_source_map)
|
||||
lines.append(_Line(line_text, k, line_annotations))
|
||||
lines = _align_annotations(lines)
|
||||
out = "\n".join(
|
||||
@ -406,6 +459,17 @@ def color(doc: Doc, *, foreground: Color | None = None,
|
||||
intensity=intensity)
|
||||
|
||||
|
||||
def source_map(doc: Doc, source: Any):
  """Wraps ``doc`` so that the text it produces is attributed to ``source``.

  A source map relates a region of the pretty-printer's text output to the
  source location that produced it. The pretty printer treats ``source`` as an
  opaque object; the only requirement is that sources can be compared for
  equality. The resulting mapping from text regions to source objects can be
  collected as a side output of the ``format`` method.
  """
  return _SourceMapDoc(child=doc, source=source)
|
||||
|
||||
type_annotation = partial(color, intensity=Intensity.NORMAL,
|
||||
foreground=Color.MAGENTA)
|
||||
keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE)
|
||||
|
@ -1477,6 +1477,15 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "pretty_printer_test",
|
||||
srcs = ["pretty_printer_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
36
tests/pretty_printer_test.py
Normal file
36
tests/pretty_printer_test.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import pretty_printer as pp
|
||||
|
||||
|
||||
class PrettyPrinterTest(jtu.JaxTestCase):
  """Tests for the pretty printer's source-map side output."""

  def testSourceMap(self):
    # Layout under test: untagged text, a region tagged 101, a region tagged
    # 77 that contains a break, then more untagged text.
    tagged_101 = pp.source_map(pp.text("def"), 101)
    tagged_77 = pp.source_map(
        pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77)
    doc = pp.concat([pp.text("abc"), tagged_101, tagged_77, pp.text("mn")])

    regions = []
    rendered = doc.format(width=8, source_map=regions)

    self.assertEqual(rendered, "abcdefgh\nijklmn")
    # One entry per output line, each a list of
    # (start column, end column, source) tuples.
    expected_regions = [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]
    self.assertEqual(regions, expected_regions)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user