rocm_jax/jax/pprint_util.py
Peter Hawkins 3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00

61 lines
1.7 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator as op
class PrettyPrint:
    """Crude Hughes-inspired pretty printer.

    A document is stored in ``self.lines`` as a list of ``(indent, text)``
    pairs, one pair per output line. ``__str__`` renders each pair as
    ``indent`` spaces followed by ``text`` and joins the lines with newlines.
    ``a + b`` stacks two documents vertically; ``a >> b`` places ``b`` to the
    right of the last line of ``a``.
    """

    def __init__(self, lines):
        # lines: list of (indent, text) tuples, one per rendered line.
        self.lines = lines

    def indent(self, indent: int) -> 'PrettyPrint':
        """Return a new document with every line shifted right by `indent`."""
        return PrettyPrint([(indent + orig_indent, s)
                            for orig_indent, s in self.lines])

    def annotate(self, length: int, msg: str) -> 'PrettyPrint':
        """Append ``" [msg]"`` to the first line, padded to `length` columns.

        Only the text of the first line is padded (its indent is not counted);
        all following lines are carried over unchanged. An empty document has
        no first line to annotate and is returned as is.
        """
        if not self.lines:
            return self
        (i, s), *rest = self.lines
        return PrettyPrint([(i, s.ljust(length) + f" [{msg}]")] + rest)

    def __add__(self, rhs: 'PrettyPrint') -> 'PrettyPrint':
        """Vertical concatenation: the lines of `self` followed by `rhs`."""
        return PrettyPrint(self.lines + rhs.lines)

    def __rshift__(self, rhs: 'PrettyPrint') -> 'PrettyPrint':
        """Horizontal concatenation: put `rhs` after the last line of `self`.

        The first line of `rhs` is appended to the last line of `self`; the
        remaining lines of `rhs` are indented so that they start in the
        column where `rhs` begins. If either side is empty, the other side
        is returned unchanged.
        """
        if not rhs.lines:
            return self
        if not self.lines:
            return rhs

        indent, s = self.lines[-1]
        # Shift rhs so its continuation lines line up under its first line.
        indented_block = rhs.indent(indent + len(s))
        # The first rhs line keeps its own indent as spaces after `s`.
        common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
        return PrettyPrint(self.lines[:-1]
                           + [(indent, common_line)]
                           + indented_block.lines[1:])

    def __str__(self) -> str:
        """Render the document, one ``indent``-prefixed line per entry."""
        return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s):
    """Build a PrettyPrint from ``str(s)``, one zero-indent entry per line.

    The text is split with ``str.splitlines``, so an empty string produces a
    document with no lines.
    """
    flush_left = []
    for text in str(s).splitlines():
        flush_left.append((0, text))
    return PrettyPrint(flush_left)
def hcat(ps):
    """Join the documents in `ps` left to right using the ``>>`` operator.

    The fold has no initial value, so `ps` must contain at least one
    element; ``functools.reduce`` raises ``TypeError`` for an empty iterable.
    """
    return functools.reduce(lambda left, right: left >> right, ps)
def vcat(ps):
    """Stack the documents in `ps` vertically, in order.

    Args:
      ps: iterable of PrettyPrint objects; may be empty.

    Returns:
      A PrettyPrint holding the lines of every element of `ps`, first to
      last. An empty `ps` yields a document with no lines.
    """
    # Gather all lines in a single pass. Folding with `+` (as `sum` does)
    # re-copies the accumulated line list for every element, which is
    # quadratic in the total number of lines.
    return PrettyPrint([line for p in ps for line in p.lines])