rocm_jax/tests/pretty_printer_test.py
Peter Hawkins d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting source map information (https://tc39.es/source-map/) for jaxprs, so that the relationship between a jaxpr and its Python code can be visualized with tooling built for that format.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information for each equation.
2024-05-06 15:45:25 -04:00

37 lines
1.2 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src import pretty_printer as pp
class PrettyPrinterTest(jtu.JaxTestCase):
  """Tests for document combinators in `jax._src.pretty_printer`."""

  def testSourceMap(self):
    """Checks the source-map side output collected while formatting.

    Every `pp.source_map(doc, tag)` wrapper is expected to contribute
    `(start_column, end_column, tag)` spans to the list passed as
    `format(..., source_map=...)`, grouped into one inner list per output
    line.
    """
    # "def" is tagged 101; the run "gh" <break> "ijkl" is tagged 77.
    tagged_def = pp.source_map(pp.text("def"), 101)
    tagged_run = pp.source_map(
        pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77)
    document = pp.concat(
        [pp.text("abc"), tagged_def, tagged_run, pp.text("mn")])

    spans = []
    # At width=8 the expected output breaks right after "abcdefgh", so the
    # 77-tagged run is split across both output lines.
    rendered = document.format(width=8, source_map=spans)

    self.assertEqual(rendered, "abcdefgh\nijklmn")
    expected_spans = [
        [(3, 6, 101), (6, 8, 77)],  # line 0: "def" then "gh"
        [(0, 4, 77)],               # line 1: "ijkl"
    ]
    self.assertEqual(spans, expected_spans)


if __name__ == "__main__":
  # Hand absltest JAX's custom test loader instead of the default one.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)