Adds a Wadler-Lindig pretty printer.

Changes jaxpr printing to use it.
This commit is contained in:
Peter Hawkins 2021-09-24 22:08:42 -04:00
parent 001d20d250
commit 5fa4613e99
15 changed files with 777 additions and 435 deletions

View File

@ -1,3 +1,4 @@
colorama>=0.4.4
flatbuffers==2.0
pillow>=8.3.1
pytest-benchmark

View File

@ -92,11 +92,11 @@ For example, here is the jaxpr produced for the function ``func1`` below
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Here there are no constvars, ``a`` and ``b`` are the input variables
@ -129,11 +129,11 @@ jaxpr as before
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
@ -155,11 +155,11 @@ before (with two input vars, one for each element of the input tuple)
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
@ -209,20 +209,17 @@ For example:
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a b.
let c = convert_element_type[ new_dtype=int32
weak_type=False ] a
d = clamp 0 c 2
e = cond[ branches=( { lambda ; a.
let b = add a 1.0
in (b,) }
{ lambda ; a.
let b = sub a 2.0
in (b,) }
{ lambda ; a.
let b = add a 3.0
in (b,) } )
linear=(False,) ] d b
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
branches=(
{ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
{ lambda ; a:f32[]. let b:f32[] = sub a 2.0 in (b,) }
{ lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) }
)
linear=(False,)
] d b
in (e,) }
The cond primitive has a number of parameters:
@ -250,20 +247,18 @@ Another example, using :py:func:`lax.cond`:
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = convert_element_type[ new_dtype=int32
weak_type=False ] b
d = cond[ branches=( { lambda ; a.
let b = sub a 3.0
in (b,) }
{ lambda ; a.
let b = add a 3.0
in (b,) } )
linear=(False,) ] c a
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; a:f32[]. let b:f32[] = sub a 3.0 in (b,) }
{ lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) }
)
linear=(False,)
] c a
in (d,) }
In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
@ -281,24 +276,23 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a ; b c d.
let e = ge b 0.0
f = convert_element_type[ new_dtype=int32
weak_type=False ] e
g = cond[ branches=( { lambda ; a b c.
let d = convert_element_type[ new_dtype=float32
weak_type=True ] a
e = add d c
in (e,) }
{ lambda ; f_ a b.
let
in (a,) } )
linear=(False, False, False) ] f a c d
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
e:bool[] = ge b 0.0
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
g:f32[1] = cond[
branches=(
{ lambda ; a:i32[1] b:f32[1] c:f32[]. let
d:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] a
e:f32[1] = add d c
in (e,) }
{ lambda ; f_:i32[1] a:f32[1] b:f32[]. let in (a,) }
)
linear=(False, False, False)
] f a c d
in (g,) }
While
^^^^^
@ -324,21 +318,22 @@ For example, here is an example fori loop
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d = add a c
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
let f = add c 1
g = mul a 3.0
h = add e g
i = add h b
in (f, d, i) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in (d,) }
cond_nconsts=0 ] c a 0 b d
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[16] = add a c
_:* _:* e:f32[16] = while[
body_jaxpr={ lambda ; a:f32[16] b:f32[16] c:i32[] d:i32[] e:f32[16]. let
f:i32[] = add c 1
g:f32[16] = mul a 3.0
h:f32[16] = add e g
i:f32[16] = add h b
in (f, d, i) }
body_nconsts=2
cond_jaxpr={ lambda ; a:i32[] b:i32[] c:f32[16]. let
d:bool[] = lt a b
in (d,) }
cond_nconsts=0
] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
@ -372,22 +367,22 @@ For the example consider the function ``func11`` below
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d e = scan[ jaxpr={ lambda ; a b c d.
let e = mul c d
f = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
h = add g a
in (h, b) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1 ] b 0.0 a c
{ lambda ; a:f32[16] b:f32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[] e:f32[16] = scan[
jaxpr={ lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let
e:f32[] = mul c d
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
g:f32[] = add f e
h:f32[] = add g a
in (h, b) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1
] b 0.0 a c
in (d, e) }
The ``linear`` parameter describes for each of the input variables whether they
@ -416,24 +411,23 @@ which the computation should run. For example
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a.
let b = sub a 2.0
c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = mul a c
e = convert_element_type[ new_dtype=float32
weak_type=False ] b
f = add e d
in (f,) }
device=None
donated_invars=(False, False)
inline=False
name=inner ] a b
d = convert_element_type[ new_dtype=float32
weak_type=False ] a
e = add d c
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0
c:f32[1] = xla_call[
backend=None
call_jaxpr={ lambda ; a:f32[] b:f32[]. let
c:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
d:f32[1] = mul a c
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
f:f32[1] = add e d
in (f,) }
device=None
donated_invars=(False, False)
inline=False
name=inner
] a b
d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
e:f32[1] = add d c
in (e,) }
@ -452,27 +446,27 @@ captured using the ``xla_pmap`` primitive. Consider this example
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a b.
let c = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; a b.
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
e = add c d
f = psum[ axes=('rows',)
axis_index_groups=None ] b
g = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(None, 0)
name=inner
out_axes=(0,) ] b a
in (c,) }
{ lambda ; a:f32[1,3] b:f32[]. let
c:f32[1,3] = xla_pmap[
axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; a:f32[] b:f32[3]. let
c:f32[3] = add b a
d:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
e:f32[3] = add c d
f:f32[3] = psum[axes=('rows',) axis_index_groups=None] b
g:f32[3] = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(None, 0)
name=inner
out_axes=(0,)
] b a
in (c,) }
The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``

View File

@ -2295,19 +2295,16 @@ def make_jaxpr(fun: Callable,
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a.
let b = cos a
c = sin b
in (c,) }
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a.
let b = cos a
c = sin a
_ = sin b
d = cos b
e = mul 1.0 d
f = neg e
g = mul f c
{ lambda ; a:f32[]. let
b:f32[] = cos a
c:f32[] = sin a
_:* = sin b
d:f32[] = cos b
e:f32[] = mul 1.0 d
f:f32[] = neg e
g:f32[] = mul f c
in (g,) }
"""
_check_callable(fun)

View File

@ -1,63 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator as op
from . import traceback_util
traceback_util.register_exclusion(__file__)
class PrettyPrint:
"""Crude Hughes-inspired pretty printer."""
def __init__(self, lines):
self.lines = lines
def indent(self, indent):
return PrettyPrint([(indent + orig_indent, s)
for orig_indent, s in self.lines])
def annotate(self, length, msg):
(i, s), *rest = self.lines
return PrettyPrint([(i, s.ljust(length) + f" [{msg}]")] + list(rest))
def __add__(self, rhs):
return PrettyPrint(self.lines + rhs.lines)
def __rshift__(self, rhs):
if not rhs.lines:
return self
if not self.lines:
return rhs
indent, s = self.lines[-1]
indented_block = rhs.indent(indent + len(s))
common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
return PrettyPrint(self.lines[:-1]
+ [(indent, common_line)]
+ indented_block.lines[1:])
def __str__(self):
return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s):
return PrettyPrint([(0, line) for line in str(s).splitlines()])
def hcat(ps):
return functools.reduce(op.rshift, ps)
def vcat(ps):
return sum(ps, pp(''))

390
jax/_src/pretty_printer.py Normal file
View File

@ -0,0 +1,390 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Wadler-Lindig pretty printer.
#
# References:
# Wadler, P., 1998. A prettier printer. Journal of Functional Programming,
# pp.223-244.
#
# Lindig, C. 2000. Strictly Pretty.
# https://lindig.github.io/papers/strictly-pretty-2000.pdf
#
# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for
# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated
#
import abc
import enum
from functools import partial
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union
try:
import colorama # pytype: disable=import-error
except ImportError:
colorama = None
class Doc(abc.ABC):
__slots__ = ()
def format(self, width: int = 80, use_color: bool = False,
annotation_prefix=" # ") -> str:
return _format(self, width, use_color=use_color,
annotation_prefix=annotation_prefix)
def __str__(self):
return self.format()
def __add__(self, other: 'Doc') -> 'Doc':
return concat([self, other])
class _NilDoc(Doc):
def __repr__(self): return "nil"
_nil = _NilDoc()
class _TextDoc(Doc):
__slots__ = ("text", "annotation")
text: str
annotation: Optional[str]
def __init__(self, text: str, annotation: Optional[str] = None):
assert isinstance(text, str), text
assert annotation is None or isinstance(annotation, str), annotation
self.text = text
self.annotation = annotation
def __repr__(self):
if self.annotation is not None:
return f"text(\"{self.text}\", annotation=\"{self.annotation}\")"
else:
return f"text(\"{self.text}\")"
class _ConcatDoc(Doc):
__slots__ = ("children",)
children: List[Doc]
def __init__(self, children: Sequence[Doc]):
self.children = list(children)
assert all(isinstance(doc, Doc) for doc in self.children), self.children
def __repr__(self): return f"concat({self.children})"
class _BreakDoc(Doc):
__slots__ = ("text",)
text: str
def __init__(self, text: str):
assert isinstance(text, str), text
self.text = text
def __repr__(self): return f"break({self.text})"
class _GroupDoc(Doc):
__slots__ = ("child",)
child: Doc
def __init__(self, child: Doc):
assert isinstance(child, Doc), child
self.child = child
def __repr__(self): return f"group({self.child})"
class _NestDoc(Doc):
__slots__ = ("n", "child",)
n: int
child: Doc
def __init__(self, n: int, child: Doc):
assert isinstance(child, Doc), child
self.n = n
self.child = child
def __repr__(self): return f"nest({self.n, self.child})"
Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE",
"MAGENTA", "CYAN", "WHITE", "RESET"])
Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"])
class _ColorDoc(Doc):
__slots__ = ("foreground", "background", "intensity", "child")
foreground: Optional[Color]
background: Optional[Color]
intensity: Optional[Intensity]
child: Doc
def __init__(self, child: Doc, *, foreground: Optional[Color] = None,
background: Optional[Color] = None,
intensity: Optional[Intensity] = None):
assert isinstance(child, Doc), child
self.child = child
self.foreground = foreground
self.background = background
self.intensity = intensity
_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
# In Lindig's paper fits() and format() are defined recursively. This is a
# non-recursive formulation using an explicit stack, necessary because Python
# doesn't have a tail recursion optimization.
def _fits(doc: Doc, width: int, agenda: List[Tuple[int, _BreakMode, Doc]]
) -> bool:
while width >= 0 and len(agenda) > 0:
i, m, doc = agenda.pop()
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
width -= len(doc.text)
elif isinstance(doc, _ConcatDoc):
agenda.extend((i, m, d) for d in reversed(doc.children))
elif isinstance(doc, _BreakDoc):
if m == _BreakMode.BREAK:
return True
width -= len(doc.text)
elif isinstance(doc, _NestDoc):
agenda.append((i + doc.n, m, doc.child))
elif isinstance(doc, _GroupDoc):
agenda.append((i, _BreakMode.FLAT, doc.child))
elif isinstance(doc, _ColorDoc):
agenda.append((i, m, doc.child))
else:
raise ValueError("Invalid document ", doc)
return width >= 0
# Annotation layout: A flat group is sparse if there are no breaks between
# annotations.
def _sparse(doc: Doc) -> bool:
agenda = [doc]
num_annotations = 0
seen_break = False
while len(agenda) > 0:
doc = agenda.pop()
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
if doc.annotation is not None:
if num_annotations >= 1 and seen_break:
return False
num_annotations += 1
elif isinstance(doc, _ConcatDoc):
agenda.extend(reversed(doc.children))
elif isinstance(doc, _BreakDoc):
seen_break = True
elif isinstance(doc, _NestDoc):
agenda.append(doc.child)
elif isinstance(doc, _GroupDoc):
agenda.append(doc.child)
elif isinstance(doc, _ColorDoc):
agenda.append(doc.child)
else:
raise ValueError("Invalid document ", doc)
return True
class _ColorState(NamedTuple):
foreground: Color
background: Color
intensity: Intensity
class _State(NamedTuple):
indent: int
mode: _BreakMode
doc: Doc
color: _ColorState
class _Line(NamedTuple):
text: str
width: int
annotations: Union[Optional[str], List[str]]
def _update_color(use_color: bool, state: _ColorState, update: _ColorState
) -> Tuple[_ColorState, str]:
if not use_color or colorama is None:
return update, ""
color_str = ""
if state.foreground != update.foreground:
color_str += getattr(colorama.Fore, str(update.foreground.name))
if state.background != update.background:
color_str += getattr(colorama.Back, str(update.background.name))
if state.intensity != update.intensity:
color_str += colorama.Style.NORMAL
color_str += getattr(colorama.Style, str(update.intensity.name))
return update, color_str
def _align_annotations(lines):
# TODO: Hafiz also implements a local alignment mode, where groups of lines
# with annotations are aligned together.
maxlen = max(l.width for l in lines)
out = []
for l in lines:
if len(l.annotations) == 0:
out.append(l._replace(annotations=None))
elif len(l.annotations) == 1:
out.append(l._replace(text=l.text + " " * (maxlen - l.width),
annotations=l.annotations[0]))
else:
out.append(l._replace(text=l.text + " " * (maxlen - l.width),
annotations=l.annotations[0]))
for a in l.annotations[1:]:
out.append(_Line(text=" " * maxlen, width=l.width, annotations=a))
return out
def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
lines = []
default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)
color_state = default_colors
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)]
k = 0
line_text = ""
line_annotations = []
while len(agenda) > 0:
i, m, doc, color = agenda.pop()
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
color_state, color_str = _update_color(use_color, color_state, color)
line_text += color_str
line_text += doc.text
if doc.annotation is not None:
line_annotations.append(doc.annotation)
k += len(doc.text)
elif isinstance(doc, _ConcatDoc):
agenda.extend(_State(i, m, d, color)
for d in reversed(doc.children))
elif isinstance(doc, _BreakDoc):
if m == _BreakMode.BREAK:
if len(line_annotations) > 0:
color_state, color_str = _update_color(use_color, color_state,
annotation_colors)
line_text += color_str
lines.append(_Line(line_text, k, line_annotations))
line_text = " " * i
line_annotations = []
k = i
else:
color_state, color_str = _update_color(use_color, color_state, color)
line_text += color_str
line_text += doc.text
k += len(doc.text)
elif isinstance(doc, _NestDoc):
agenda.append(_State(i + doc.n, m, doc.child, color))
elif isinstance(doc, _GroupDoc):
# In Lindig's paper, _fits is passed the remainder of the document.
# I'm pretty sure that's a bug and we care only if the current group fits!
if (_sparse(doc)
and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])):
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color))
else:
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color))
elif isinstance(doc, _ColorDoc):
color = _ColorState(doc.foreground or color.foreground,
doc.background or color.background,
doc.intensity or color.intensity)
agenda.append(_State(i, m, doc.child, color))
else:
raise ValueError("Invalid document ", doc)
if len(line_annotations) > 0:
color_state, color_str = _update_color(use_color, color_state,
annotation_colors)
line_text += color_str
lines.append(_Line(line_text, k, line_annotations))
lines = _align_annotations(lines)
out = "\n".join(
l.text if l.annotations is None
else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines)
color_state, color_str = _update_color(use_color, color_state,
default_colors)
return out + color_str
# Public API.
def nil() -> Doc:
"""An empty document."""
return _nil
def text(s: str, annotation: Optional[str] = None) -> Doc:
"""Literal text."""
return _TextDoc(s, annotation)
def concat(docs: Sequence[Doc]) -> Doc:
"""Concatenation of documents."""
docs = list(docs)
if len(docs) == 1:
return docs[0]
return _ConcatDoc(docs)
def brk(text: str = " ") -> Doc:
"""A break.
Prints either as a newline or as `text`, depending on the enclosing group.
"""
return _BreakDoc(text)
def group(doc: Doc) -> Doc:
"""Layout alternative groups.
Prints the group with its breaks as their text (typically spaces) if the
entire group would fit on the line when printed that way. Otherwise, breaks
inside the group as printed as newlines.
"""
return _GroupDoc(doc)
def nest(n: int, doc: Doc) -> Doc:
"""Increases the indentation level by `n`."""
return _NestDoc(n, doc)
def color(doc: Doc, *, foreground: Optional[Color] = None,
background: Optional[Color] = None,
intensity: Optional[Intensity] = None):
"""ANSI colors.
Overrides the foreground/background/intensity of the text for the child doc.
Requires use_colors=True to be set when printing and the `colorama` package
to be installed; otherwise does nothing.
"""
return _ColorDoc(doc, foreground=foreground, background=background,
intensity=intensity)
dim = partial(color, intensity=Intensity.DIM)
bright = partial(color, intensity=Intensity.BRIGHT)
def join(sep: Doc, docs: Sequence[Doc]) -> Doc:
"""Concatenates `docs`, separated by `sep`."""
docs = list(docs)
if len(docs) == 0:
return nil()
xs = [docs[0]]
for doc in docs[1:]:
xs.append(sep)
xs.append(doc)
return concat(xs)

View File

@ -30,7 +30,7 @@ from jax._src.api import jit, vmap
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.pprint_util import pp, vcat
import jax._src.pretty_printer as pp
from jax._src.util import prod
@ -63,9 +63,10 @@ class PRNGImpl(NamedTuple):
fold_in: Callable
def pprint(self):
return (pp(self.__class__.__name__) >> pp(':')) + vcat([
pp(k) >> pp(' = ') >> pp(v) for k, v in self._asdict().items()
]).indent(2)
return (pp.text(f"{self.__class__.__name__}:") +
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
]))))
# -- PRNG key arrays --
@ -182,9 +183,11 @@ class PRNGKeyArray:
def __repr__(self):
arr_shape = self._shape
pp_keys = pp('shape = ') >> pp(arr_shape)
pp_impl = pp('impl = ') >> self.impl.pprint()
return str(pp('PRNGKeyArray:') + (pp_keys + pp_impl).indent(2))
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
pp_impl = pp.text('impl = ') + self.impl.pprint()
return str(pp.group(
pp.text('PRNGKeyArray:') +
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:

View File

@ -40,7 +40,7 @@ from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
tuple_insert, tuple_delete, cache, as_hashable_function,
HashableFunction)
from ._src.pprint_util import pp, vcat, PrettyPrint
import jax._src.pretty_printer as pp
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
@ -77,6 +77,10 @@ class Jaxpr:
return str(pp_jaxpr(self))
__repr__ = __str__
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
doc = pp_jaxpr(self, source_info=source_info, print_shapes=print_shapes)
return doc.format(**kw)
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
for val in params.values():
@ -128,6 +132,10 @@ class ClosedJaxpr:
def __str__(self): return str(self.jaxpr)
def __repr__(self): return repr(self.jaxpr)
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
return pp_jaxpr(self.jaxpr, source_info=source_info,
print_shapes=print_shapes).format(**kw)
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
@ -574,13 +582,20 @@ class Tracer:
else:
return attr
def __repr__(self):
base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
contents = [(name, pp(repr(attr))) for name, attr in self._contents()]
def _pretty_print(self):
base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
else pp.text(repr(attr))) for name, attr in self._contents()]
if contents:
base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
for name, pp_payload in contents)
return str(base)
base = pp.group(pp.nest(2, pp.concat([
base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
pp.text('{} = '.format(name)) + pp_payload
for name, pp_payload in contents])
])))
return base
def __repr__(self):
return self._pretty_print().format()
def _contents(self):
try:
@ -894,7 +909,7 @@ class AbstractValue:
def update(self, **kwargs):
raise NotImplementedError("must override")
def str_short(self):
def str_short(self, short_dtypes=False):
raise NotImplementedError("must override")
class Bot(AbstractValue): pass
@ -910,7 +925,7 @@ class AbstractUnit(AbstractValue):
assert other is abstract_unit, other
return self
def _eq(self, self_traced, other): return get_aval(other) is self
def str_short(self): return '*'
def str_short(self, short_dtypes=False): return '*'
abstract_unit = AbstractUnit()
@ -1005,6 +1020,11 @@ def concrete_or_error(force: Any, val: Any, context=""):
convert_element_type_p = Primitive('convert_element_type')
def _short_dtype_name(dtype):
return (dtype.name.replace('float', 'f').replace('uint', 'u')
.replace('int', 'i').replace('complex', 'c'))
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 2
@ -1057,8 +1077,8 @@ class UnshapedArray(AbstractValue):
else:
raise TypeError(self, other)
def str_short(self) -> str:
return self.dtype.name
def str_short(self, short_dtypes=False) -> str:
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
def strip_weak_type(self):
"""Returns a copy of the aval with weak_type=False."""
@ -1126,13 +1146,14 @@ class ShapedArray(UnshapedArray):
else:
raise TypeError(self, other)
def str_short(self):
def str_short(self, short_dtypes=False):
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
shapestr = ','.join(map(str, self.shape))
if self.named_shape:
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
return f'{self.dtype.name}[{shapestr};{named_shapestr}]'
return f'{dt_str}[{shapestr};{named_shapestr}]'
else:
return f'{self.dtype.name}[{shapestr}]'
return f'{dt_str}[{shapestr}]'
def strip_named_shape(self):
return self.update(named_shape={})
@ -1193,8 +1214,9 @@ class ConcreteArray(ShapedArray):
else:
raise TypeError(self, other)
def str_short(self) -> str:
return f'{self.val}, dtype={self.dtype.name}'
def str_short(self, short_dtypes=False) -> str:
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
@ -1216,7 +1238,7 @@ class AbstractToken(AbstractValue):
return self
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self): return 'Tok'
def str_short(self, short_dtypes=False): return 'Tok'
def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()
@ -1954,7 +1976,7 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
except JaxprTypeError as e:
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)),
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
@ -2028,81 +2050,104 @@ def check_map(prim, in_avals, params):
# ------------------- Jaxpr printed representation -------------------
def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str:
def pp_vars(vs: Sequence[Any], *, print_shapes: bool = False) -> pp.Doc:
if print_shapes:
return ' '.join(f'{v}:{v.aval.str_short()}' for v in vs)
return pp.nest(2, pp.group(
pp.join(pp.brk(), [
pp.text(str(v)) +
pp.dim(pp.text(":" + v.aval.str_short(short_dtypes=True)))
for v in vs
])
))
else:
return ' '.join(map(str, vs))
return pp.nest(2, pp.group(
pp.join(pp.brk(), [pp.text(str(v)) for v in vs])
))
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
def pp_kv_pair(k:str, v: Any) -> pp.Doc:
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
pp_v = pp_jaxprs(v)
elif isinstance(v, Jaxpr):
pp_v = pp_jaxpr(v)
elif isinstance(v, ClosedJaxpr):
pp_v = pp_jaxpr(v.jaxpr)
else:
pp_v = pp.text(str(v))
return pp.text(f'{k}=') + pp_v
def pp_kv_pairs(kv_pairs) -> pp.Doc:
if not kv_pairs:
return pp.nil()
return pp.group(
pp.nest(2, pp.concat([
pp.text("["), pp.brk(""),
pp.join(pp.brk(), [pp_kv_pair(k, v) for k, v in kv_pairs])
]))
+ pp.brk("") + pp.text("]")
)
def pp_eqn(eqn, *, print_shapes=True, source_info=False) -> pp.Doc:
lhs = pp_vars(eqn.outvars, print_shapes=print_shapes)
annotation = (source_info_util.summarize(eqn.source_info)
if source_info else None)
return pp.concat([
lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
pp_kv_pairs(sorted(eqn.params.items())),
pp.text(" ") + pp_vars(eqn.invars)
])
def pp_eqns(eqns, *, print_shapes=True, source_info=False) -> pp.Doc:
return pp.join(
pp.brk("; "),
map(partial(pp_eqn, print_shapes=print_shapes, source_info=source_info),
eqns))
def pp_eqn_compact(primitive_name: str, params: Dict) -> pp.Doc:
filtered_params = {k: v for k, v in params.items()
if (k != 'branches' and
not isinstance(v, (Jaxpr, ClosedJaxpr)))}
return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))
return pp.text(primitive_name) + pp_kv_pairs(sorted(filtered_params.items()))
def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
lhs = pp_vars(eqn.outvars, print_shapes)
pp_lhs = pp(f'{lhs} =')
pp_rhs = (pp(eqn.primitive.name) >>
pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
pp(pp_vars(eqn.invars, print_shapes)))
if len(lhs) <= 6 or print_shapes:
return pp_lhs >> pp(' ') >> pp_rhs
else:
return pp_lhs + pp_rhs.indent(2)
def pp_eqns(eqns: Sequence[JaxprEqn],
source_info: bool = False) -> Sequence[PrettyPrint]:
pps = map(pp_eqn, eqns)
if source_info:
l = max((i + len(s) for x in pps for i, s in x.lines), default=None)
if l is not None:
return [p.annotate(l, source_info_util.summarize(e.source_info))
for e, p in zip(eqns, pps)]
return pps
def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint:
pps = pp_eqns(jaxpr.eqns, source_info=source_info)
def pp_jaxpr_skeleton(jaxpr, eqns_pp, *, print_shapes=True) -> pp.Doc:
str_outvars = str(tuple(jaxpr.outvars))
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.invars))) +
((pp('let ') >> vcat(pps))
+ pp('in {} }}'.format(str_outvars))).indent(2))
return pp.group(pp.nest(2, pp.concat([
pp.text("{ "), pp.bright(pp.text("lambda ")),
pp_vars(jaxpr.constvars, print_shapes=print_shapes),
pp.text("; "), pp_vars(jaxpr.invars, print_shapes=print_shapes),
pp.text(". "), pp.bright(pp.text("let")),
pp.nest(2, pp.brk() + eqns_pp), pp.brk(),
pp.bright(pp.text("in")),
pp.text(f" {str_outvars}")
])) + pp.text(" }"))
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int,
source_info: bool = False) -> PrettyPrint:
def pp_jaxpr(jaxpr, *, print_shapes=True, source_info=False) -> pp.Doc:
pps = pp_eqns(jaxpr.eqns, print_shapes=print_shapes, source_info=source_info)
return pp_jaxpr_skeleton(jaxpr, pps, print_shapes=print_shapes)
def pp_jaxprs(jaxprs) -> pp.Doc:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp.group(pp.nest(2, pp.concat([
pp.text('('), pp.brk(""), pp.join(pp.brk(), map(pp_jaxpr, jaxprs))]))
+ pp.brk("") + pp.text(')')
)
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, print_shapes=True,
source_info: bool = False) -> pp.Doc:
lo = max(lo, 0)
hi = max(lo, min(hi, len(jaxpr.eqns)))
eqns = jaxpr.eqns[lo:hi]
pps = []
if len(eqns) == 0 and len(jaxpr.eqns) != 0:
pps.append(pp('...'))
pps.append(pp.text('...'))
else:
if lo != 0:
pps.append(pp('...'))
pps.extend(pp_eqns(eqns, source_info=source_info))
pps.append(pp.text('...'))
pps.extend(map(partial(pp_eqn, print_shapes=print_shapes,
source_info=source_info), eqns))
if hi != len(jaxpr.eqns):
pps.append(pp('...'))
str_outvars = str(tuple(jaxpr.outvars))
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.invars))) +
((pp('let ') >> vcat(pps))
+ pp('in {} }}'.format(str_outvars))).indent(2))
def pp_jaxprs(jaxprs) -> PrettyPrint:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp('( ') >> vcat(map(pp_jaxpr, jaxprs)) >> pp(' )')
def pp_kv_pair(k, v):
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
pp_v = pp_jaxprs(v)
else:
pp_v = pp(v)
return pp(f'{k}=') >> pp_v
def pp_kv_pairs(kv_pairs):
if kv_pairs:
return pp('[ ') >> vcat([pp_kv_pair(k, v) for k, v in kv_pairs]) >> pp(' ]')
else:
return pp('')
pps.append(pp.text('...'))
return pp_jaxpr_skeleton(jaxpr, pp.join(pp.brk("; "), pps),
print_shapes=print_shapes)

View File

@ -23,7 +23,7 @@ from jax._src import dtypes
from jax.core import Var, Literal, Atom, Tracer
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
tuple_delete)
from jax._src.pprint_util import pp, vcat, PrettyPrint
import jax._src.pretty_printer as pp
map = safe_map
zip = safe_zip
@ -177,7 +177,7 @@ class DJaxpr:
def __repr__(self):
return str(pp_djaxpr(self))
def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint:
def pp_djaxpr(jaxpr: DJaxpr) -> pp.Doc:
eqns = map(pp_eqn, jaxpr.eqns)
in_dim_binders = pp_vars(jaxpr.in_dim_binders)
in_binders = pp_vars(jaxpr.in_binders)
@ -185,21 +185,21 @@ def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint:
outs = ', '.join(map(str, jaxpr.outs))
out_dim_types = pp_vars(jaxpr.out_dims)
outs_type = ', '.join(v.aval.str_short() for v in jaxpr.outs)
return (pp(f'{{ lambda {in_dim_binders} ; {in_binders} .')
+ (pp('let ') >> vcat(eqns) +
pp(f'in ( {out_dims} ; {outs} ) '
f': ( {out_dim_types} ; {outs_type} ) }}')).indent(2))
return (pp.text(f'{{ lambda {in_dim_binders} ; {in_binders} .')
+ (pp.text('let ') + pp.nest(2, pp.brk() + pp.join(pp.brk(), eqns)) +
pp.text(f'in ( {out_dims} ; {outs} ) '
f': ( {out_dim_types} ; {outs_type} ) }}')))
def pp_vars(vs: Sequence[Atom]) -> str:
return ', '.join(f'{v}:{v.aval.str_short()}' for v in vs)
def pp_eqn(eqn: core.JaxprEqn) -> PrettyPrint:
def pp_eqn(eqn: core.JaxprEqn) -> pp.Doc:
lhs = pp_vars(eqn.outvars)
pp_lhs = pp(f'{lhs} =')
pp_rhs = (pp(eqn.primitive.name) >>
core.pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
pp(' '.join(map(str, eqn.invars))))
return pp_lhs >> pp(' ') >> pp_rhs
pp_lhs = pp.text(f'{lhs} =')
pp_rhs = (pp.text(eqn.primitive.name) +
core.pp_kv_pairs(sorted(eqn.params.items())) + pp.text(' ') +
pp.text(' '.join(map(str, eqn.invars))))
return pp_lhs + pp.text(' ') + pp_rhs
# Typechecking DJaxprs

View File

@ -458,7 +458,7 @@ from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax._src import pprint_util as ppu
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
from jax._src.lib import pytree
@ -747,21 +747,31 @@ def _print_tap_func(
f"{k}: {v}" for k, v in sorted(kwargs.items())
])
def pp_val(arg) -> ppu.PrettyPrint:
def pp_val(arg) -> pp.Doc:
if isinstance(arg, tuple):
return (
ppu.pp("( ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" )"))
return pp.group(pp.concat([
pp.text("( "),
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
pp.text(" )")
]))
elif isinstance(arg, list):
return (
ppu.pp("[ ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" ]"))
return pp.group(pp.concat([
pp.text("[ "),
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
pp.text(" ]")
]))
elif isinstance(arg, dict):
return (ppu.pp("{ ") >> ppu.vcat([
ppu.pp(f"{k}=") >> pp_val(v) for k, v in sorted(arg.items())
]) >> ppu.pp(" }"))
return pp.group(pp.concat([
pp.text("{ "),
pp.nest(2, pp.join(pp.brk(), [
pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items())
])),
pp.text(" }")
]))
elif isinstance(arg, np.ndarray):
return ppu.pp(np.array2string(arg, threshold=threshold))
return pp.text(np.array2string(arg, threshold=threshold))
else:
return ppu.pp(str(arg))
return pp.text(str(arg))
with _print_tap_lock:
if kv_pairs:

View File

@ -36,7 +36,7 @@ from jax._src.abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
from ..errors import UnexpectedTracerError
from jax._src.pprint_util import pp
import jax._src.pretty_printer as pp
from .._src.util import (partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map,
partition_list)
@ -112,11 +112,12 @@ def make_op_metadata(primitive: core.Primitive,
name_stack: str = "",
source_info: Optional[source_info_util.Traceback] = None
) -> xc.OpMetadata:
tracebacks[str(pp(name_stack) >> pp_eqn_compact(primitive.name, params))] = source_info
eqn_str = str(pp.text(name_stack) + pp_eqn_compact(primitive.name, params))
tracebacks[eqn_str] = source_info
frame = source_info_util.user_frame(source_info) if source_info else None
return xc.OpMetadata(
op_type=primitive.name,
op_name=str(pp(name_stack) >> pp_eqn_compact(primitive.name, params)),
op_name=eqn_str,
source_file=_get_canonical_source_file(frame) if frame else None,
source_line=frame.line_num if frame else None)

View File

@ -4,6 +4,8 @@ disable_error_code = attr-defined
[mypy-absl.*]
ignore_missing_imports = True
[mypy-colorama.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-opt_einsum.*]

View File

@ -3722,15 +3722,10 @@ class JaxprTest(jtu.JaxTestCase):
def test_const(self):
def fun(x):
return (x, 1., np.zeros(1))
return (x, 1., np.zeros(1, dtype=jnp.float32))
expected = """
{ lambda a ; b.
let
in (b, 1.0, a) }
"""
jaxpr = api.make_jaxpr(fun)(0.)
expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }"
jaxpr = api.make_jaxpr(fun)(jnp.float32(0.))
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
def test_cond(self):
@ -3740,23 +3735,24 @@ class JaxprTest(jtu.JaxTestCase):
lambda xt: xt + x,
x + 2.,
lambda xf: xf - x)
expected = """
{ lambda ; a.
let b = ge a 0.0
c = add a 1.0
d = add a 2.0
e = convert_element_type[ new_dtype=int32
weak_type=False ] b
f = cond[ branches=( { lambda ; e_ a b c.
let d = sub c a
in (d,) }
{ lambda ; a f_ b c.
let d = add b a
in (d,) } )
linear=(False, False, False, False) ] e a a c d
in (f,) }
"""
jaxpr = api.make_jaxpr(f)(3.)
expected = """{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0
c:f32[] = add a 1.0
d:f32[] = add a 2.0
e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
f:f32[] = cond[
branches=(
{ lambda ; e_:f32[] a:f32[] b:f32[] c:f32[]. let
d:f32[] = sub c a
in (d,) }
{ lambda ; a:f32[] f_:f32[] b:f32[] c:f32[]. let
d:f32[] = add b a
in (d,) }
)
linear=(False, False, False, False)
] e a a c d
in (f,) }"""
jaxpr = api.make_jaxpr(f)(jnp.float32(3.))
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
def test_make_jaxpr_static_argnums(self):

View File

@ -411,7 +411,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
return jnp.sin(x) + jnp.cos(x)
def new_jaxpr():
return make_jaxpr(f)(1.).jaxpr
return make_jaxpr(f)(jnp.float32(1.)).jaxpr
# jaxpr is:
#
@ -424,19 +424,21 @@ class JaxprTypeChecks(jtu.JaxTestCase):
# NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'
jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2) # int, not float!
# int, not float!
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .",
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .",
lambda: core.check_jaxpr(jaxpr))
jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3)))
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(
np.ones((2, 3), dtype=jnp.float32))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .",
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .",
lambda: core.check_jaxpr(jaxpr))
def test_jaxpr_dropvar_from_jit_call(self):

View File

@ -275,8 +275,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( 6.00
9.00 )""", testing_stream.output)
( 6.00 9.00 )""", testing_stream.output)
def test_tap_with_dict_results(self):
def func2(x):
@ -286,8 +285,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(3. * (2. + 3.), func2(3.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
{ a=6.00
b=9.00 }""", testing_stream.output)
{ a=6.00 b=9.00 }""", testing_stream.output)
def test_tap_with_result(self):
def func2(x):
@ -298,8 +296,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(3. * 4., func2(3.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( 6.00
9.00 )""", testing_stream.output)
( 6.00 9.00 )""", testing_stream.output)
def test_tap_with_result_no_arg(self):
def tap_func(arg, transforms):
@ -337,8 +334,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiDeviceOutputEqual(self, """
device: cpu:0
( 6.00
9.00 )""")
( 6.00 9.00 )""")
def test_tap_eval_exception(self):
if not FLAGS.jax_host_callback_outfeed:
@ -375,8 +371,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
assertMultiLineStrippedEqual(self, """
( )
what: second
( 1.00
[] )""", testing_stream.output)
( 1.00 [] )""", testing_stream.output)
def test_tap_jit_simple(self):
jit_fun1 = jax.jit(lambda x: 3. * hcb.id_print(
@ -906,11 +901,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ['jvp'] what: a * 2
( 10.00
0.20 )
( 10.00 0.20 )
transforms: ['jvp'] what: y * 3
( 30.00
0.60 )""", testing_stream.output)
( 30.00 0.60 )""", testing_stream.output)
def test_tap_grad_primal_unused(self):
# The output of id_print is not needed for backwards pass
@ -925,23 +918,27 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
treedef = tree_util.tree_structure(arg)
print(jaxpr)
assertMultiLineStrippedEqual(self, f"""
{{ lambda ; a.
let b = mul a 3.00
c = outside_call[ arg_treedef={treedef}
callback=...
identity=True
transforms=( ) ] b
_ = mul c 2.00
d = mul 1.00 2.00
_ = broadcast_in_dim[ broadcast_dimensions=( )
shape=( ) ] 0.00
e = outside_call[ arg_treedef={treedef}
callback=...
identity=True
transforms=(('jvp',), ('transpose',)) ] d
f = mul e 3.00
in (f,) }}""", jaxpr)
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=()
] b
_:* = mul c 2.00
d:f32[] = mul 1.00 2.00
_:* = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=(('jvp',), ('transpose',))
] d
f:f32[] = mul e 3.00
in (f,) }}""", jaxpr)
assertMultiLineStrippedEqual(self, "", testing_stream.output)
testing_stream.reset()
@ -1016,11 +1013,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: pair
( 10.00
15.00 )
( 10.00 15.00 )
transforms: ['jvp', 'transpose'] what: pair
( 0.00
0.00 )""", testing_stream.output)
( 0.00 0.00 )""", testing_stream.output)
def test_tap_jvp_float0(self):
def f(x, yint):
@ -1042,11 +1037,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: pair
( 5.00
2 )
( 5.00 2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00
False )""", testing_stream.output)
( 2.00 False )""", testing_stream.output)
def test_tap_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
@ -1068,11 +1061,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(dtypes.float0, g[1].dtype)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00]
[False False False] )""", testing_stream.output)
( [0.00 0.00] [False False False] )""", testing_stream.output)
def test_tap_higher_order_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
@ -1100,8 +1091,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
res = f_jax(x)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )""", testing_stream.output)
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
testing_stream.reset()
# 1st order
@ -1109,11 +1099,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
res_vjp1 = f_jax_vjp1(*args_vjp1)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00]
[False False False] )""", testing_stream.output)
( [0.00 0.00] [False False False] )""", testing_stream.output)
testing_stream.reset()
# 2nd order
@ -1149,8 +1137,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (None, 0)})]
( 3.00
[4.00 5.00] )""", testing_stream.output)
( 3.00 [4.00 5.00] )""", testing_stream.output)
def test_tap_vmap_vmap(self):
# A 2D tensor with x[i, j] = i + j using 2 vmap
@ -1249,8 +1236,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
expected = """
what: x,x^2
( 3.
9. )"""
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1258,8 +1244,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [0. 1. 2.]
[0. 1. 4.] )"""
( [0. 1. 2.] [0. 1. 4.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1267,10 +1252,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
expected = """
transforms: ['jvp'] what: x,x^2
( ( 3.
9. )
( 0.1
0.6 ) )"""
( ( 3. 9. ) ( 0.1 0.6 ) )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1278,11 +1260,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
expected = """
what: x,x^2
( 3.
9. )
( 3. 9. )
transforms: ['jvp', 'transpose'] what: x,x^2
( 0.
3. )"""
( 0. 3. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@ -1290,11 +1270,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.]
[4. 9.] )
( [2. 3.] [4. 9.] )
transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
( 0.
[2. 3.] )"""
( 0. [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
@ -1320,11 +1298,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
assertMultiDeviceOutputEqual(
self, """
device: cpu:0 what: x,x^2
( 3
9 )
( 3 9 )
device: cpu:1 what: x,x^2
( 4
16 )""")
( 4 16 )""")
def test_tap_pmap_vmap(self):
# A matrix M[ij] = i * 10 + j
@ -1440,13 +1416,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[ 0.00 2.00 4.00]
[20.00 22.00 24.00]]
[[0.20 0.20 0.20]
[20.00 22.00 24.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[200.00 202.00 204.00]
[220.00 222.00 224.00]]
[[0.20 0.20 0.20]
[220.00 222.00 224.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )""")
def test_tap_vmap_pmap(self):
@ -1693,10 +1667,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.] )
( ( 3 )
( 3 ) ) )""", testing_stream.output)
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# With VMAP
@ -1713,8 +1685,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[5. 6. 7. 8. 9.]]
[[ 0. 2. 4. 6. 8.]
[10. 12. 14. 16. 18.]] )
( ( [3. 4.] )
( [3. 4.] ) ) )""", testing_stream.output)
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
testing_stream.reset()
# With JVP
@ -1724,14 +1695,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.] )
( ( 3 )
( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4]
[0. 0.2 0.4 0.6 0.8] )
( ( False )
( False ) ) ) )""", testing_stream.output)
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
testing_stream.reset()
# Now with JIT
@ -1739,10 +1705,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.] )
( ( 3 )
( 3 ) ) )""", testing_stream.output)
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
def test_tap_callback_delay(self):
hcb.callback_extra = lambda dev: time.sleep(1)

View File

@ -80,7 +80,7 @@ class MetadataTest(jtu.JaxTestCase):
return jax.lax.cond(True, x, true_fun, x, false_fun)
hlo = jax.xla_computation(f)(1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_type="cond"')
self.assertRegex(hlo, 'op_name=".*cond\\[ linear=\\(False, False\\) \\]"')
self.assertRegex(hlo, 'op_name=".*cond\\[linear=\\(False, False\\)\\]"')
self.assertRegex(hlo, 'op_type="cos"')
self.assertRegex(hlo, 'op_name=".*cond/branch_0_fun/cos"')
self.assertRegex(hlo, 'op_type="sin"')