diff --git a/build/test-requirements.txt b/build/test-requirements.txt index b49c45f20..4b15db07c 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,3 +1,4 @@ +colorama>=0.4.4 flatbuffers==2.0 pillow>=8.3.1 pytest-benchmark diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index d5eb96d7c..cd89b5618 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -92,11 +92,11 @@ For example, here is the jaxpr produced for the function ``func1`` below ... return jnp.sum(temp) ... >>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a b. - let c = sin b - d = mul c 3.0 - e = add a d - f = reduce_sum[ axes=(0,) ] e +{ lambda ; a:f32[8] b:f32[8]. let + c:f32[8] = sin b + d:f32[8] = mul c 3.0 + e:f32[8] = add a d + f:f32[] = reduce_sum[axes=(0,)] e in (f,) } Here there are no constvars, ``a`` and ``b`` are the input variables @@ -129,11 +129,11 @@ jaxpr as before ... return func2(inner, first, second) ... >>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a b. - let c = sin b - d = mul c 3.0 - e = add a d - f = reduce_sum[ axes=(0,) ] e +{ lambda ; a:f32[8] b:f32[8]. let + c:f32[8] = sin b + d:f32[8] = mul c 3.0 + e:f32[8] = add a d + f:f32[] = reduce_sum[axes=(0,)] e in (f,) } @@ -155,11 +155,11 @@ before (with two input vars, one for each element of the input tuple) ... return jnp.sum(temp) ... >>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))) -{ lambda ; a b. - let c = sin b - d = mul c 3.0 - e = add a d - f = reduce_sum[ axes=(0,) ] e +{ lambda ; a:f32[8] b:f32[8]. let + c:f32[8] = sin b + d:f32[8] = mul c 3.0 + e:f32[8] = add a d + f:f32[] = reduce_sum[axes=(0,)] e in (f,) } @@ -209,20 +209,17 @@ For example: ... arg) ... >>> print(make_jaxpr(one_of_three)(1, 5.)) -{ lambda ; a b. - let c = convert_element_type[ new_dtype=int32 - weak_type=False ] a - d = clamp 0 c 2 - e = cond[ branches=( { lambda ; a. - let b = add a 1.0 - in (b,) } - { lambda ; a. - let b = sub a 2.0 - in (b,) } - { lambda ; a. - let b = add a 3.0 - in (b,) } ) - linear=(False,) ] d b +{ lambda ; a:i32[] b:f32[]. let + c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a + d:i32[] = clamp 0 c 2 + e:f32[] = cond[ + branches=( + { lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) } + { lambda ; a:f32[]. let b:f32[] = sub a 2.0 in (b,) } + { lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) } + ) + linear=(False,) + ] d b in (e,) } The cond primitive has a number of parameters: @@ -250,20 +247,18 @@ Another example, using :py:func:`lax.cond`: ... arg) ... >>> print(make_jaxpr(func7)(5.)) -{ lambda ; a. - let b = ge a 0.0 - c = convert_element_type[ new_dtype=int32 - weak_type=False ] b - d = cond[ branches=( { lambda ; a. - let b = sub a 3.0 - in (b,) } - { lambda ; a. - let b = add a 3.0 - in (b,) } ) - linear=(False,) ] c a +{ lambda ; a:f32[]. let + b:bool[] = ge a 0.0 + c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b + d:f32[] = cond[ + branches=( + { lambda ; a:f32[]. let b:f32[] = sub a 3.0 in (b,) } + { lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) } + ) + linear=(False,) + ] c a in (d,) } - In this case, the boolean predicate is converted to an integer index (0 or 1), and ``branches`` are jaxprs that correspond to the false and true branch functionals, in that order. Again, each functional takes @@ -281,24 +276,23 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` ... arg2) ... >>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) -{ lambda a ; b c d. - let e = ge b 0.0 - f = convert_element_type[ new_dtype=int32 - weak_type=False ] e - g = cond[ branches=( { lambda ; a b c. - let d = convert_element_type[ new_dtype=float32 - weak_type=True ] a - e = add d c - in (e,) } - { lambda ; f_ a b. - let - in (a,) } ) - linear=(False, False, False) ] f a c d +{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let + e:bool[] = ge b 0.0 + f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e + g:f32[1] = cond[ + branches=( + { lambda ; a:i32[1] b:f32[1] c:f32[]. let + d:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] a + e:f32[1] = add d c + in (e,) } + { lambda ; f_:i32[1] a:f32[1] b:f32[]. let in (a,) } + ) + linear=(False, False, False) + ] f a c d in (g,) } - While ^^^^^ @@ -324,21 +318,22 @@ For example, here is an example fori loop ... arg + ones) ... >>> print(make_jaxpr(func10)(np.ones(16), 5)) -{ lambda ; a b. - let c = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(16,) ] 1.0 - d = add a c - _ _ e = while[ body_jaxpr={ lambda ; a b c d e. - let f = add c 1 - g = mul a 3.0 - h = add e g - i = add h b - in (f, d, i) } - body_nconsts=2 - cond_jaxpr={ lambda ; a b c. - let d = lt a b - in (d,) } - cond_nconsts=0 ] c a 0 b d +{ lambda ; a:f32[16] b:i32[]. let + c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 + d:f32[16] = add a c + _:* _:* e:f32[16] = while[ + body_jaxpr={ lambda ; a:f32[16] b:f32[16] c:i32[] d:i32[] e:f32[16]. let + f:i32[] = add c 1 + g:f32[16] = mul a 3.0 + h:f32[16] = add e g + i:f32[16] = add h b + in (f, d, i) } + body_nconsts=2 + cond_jaxpr={ lambda ; a:i32[] b:i32[] c:f32[16]. let + d:bool[] = lt a b + in (d,) } + cond_nconsts=0 + ] c a 0 b d in (e,) } The while primitive takes 5 arguments: ``c a 0 b d``, as follows: @@ -372,22 +367,22 @@ For the example consider the function ``func11`` below ... return lax.scan(body, 0., (arr, ones)) ... >>> print(make_jaxpr(func11)(np.ones(16), 5.)) -{ lambda ; a b. - let c = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(16,) ] 1.0 - d e = scan[ jaxpr={ lambda ; a b c d. - let e = mul c d - f = convert_element_type[ new_dtype=float32 - weak_type=False ] b - g = add f e - h = add g a - in (h, b) } - length=16 - linear=(False, False, False, False) - num_carry=1 - num_consts=1 - reverse=False - unroll=1 ] b 0.0 a c +{ lambda ; a:f32[16] b:f32[]. let + c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 + d:f32[] e:f32[16] = scan[ + jaxpr={ lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let + e:f32[] = mul c d + f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b + g:f32[] = add f e + h:f32[] = add g a + in (h, b) } + length=16 + linear=(False, False, False, False) + num_carry=1 + num_consts=1 + reverse=False + unroll=1 + ] b 0.0 a c in (d, e) } The ``linear`` parameter describes for each of the input variables whether they @@ -416,24 +411,23 @@ which the computation should run. For example ... return arg + inner(arg - 2.) ... >>> print(make_jaxpr(func12)(1.)) -{ lambda ; a. - let b = sub a 2.0 - c = xla_call[ backend=None - call_jaxpr={ lambda ; a b. - let c = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(1,) ] 1.0 - d = mul a c - e = convert_element_type[ new_dtype=float32 - weak_type=False ] b - f = add e d - in (f,) } - device=None - donated_invars=(False, False) - inline=False - name=inner ] a b - d = convert_element_type[ new_dtype=float32 - weak_type=False ] a - e = add d c +{ lambda ; a:f32[]. let + b:f32[] = sub a 2.0 + c:f32[1] = xla_call[ + backend=None + call_jaxpr={ lambda ; a:f32[] b:f32[]. let + c:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 + d:f32[1] = mul a c + e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b + f:f32[1] = add e d + in (f,) } + device=None + donated_invars=(False, False) + inline=False + name=inner + ] a b + d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a + e:f32[1] = add d c in (e,) } @@ -452,27 +446,27 @@ captured using the ``xla_pmap`` primitive. Consider this example ... return pmap(inner, axis_name='rows')(arr) ... >>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) -{ lambda ; a b. - let c = xla_pmap[ axis_name=rows - axis_size=1 - backend=None - call_jaxpr={ lambda ; a b. - let c = add b a - d = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(1,) ] 1.0 - e = add c d - f = psum[ axes=('rows',) - axis_index_groups=None ] b - g = div e f - in (g,) } - devices=None - donated_invars=(False, False) - global_arg_shapes=(None,) - global_axis_size=None - in_axes=(None, 0) - name=inner - out_axes=(0,) ] b a - in (c,) } +{ lambda ; a:f32[1,3] b:f32[]. let + c:f32[1,3] = xla_pmap[ + axis_name=rows + axis_size=1 + backend=None + call_jaxpr={ lambda ; a:f32[] b:f32[3]. let + c:f32[3] = add b a + d:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 + e:f32[3] = add c d + f:f32[3] = psum[axes=('rows',) axis_index_groups=None] b + g:f32[3] = div e f + in (g,) } + devices=None + donated_invars=(False, False) + global_arg_shapes=(None,) + global_axis_size=None + in_axes=(None, 0) + name=inner + out_axes=(0,) + ] b a + in (c,) } The ``xla_pmap`` primitive specifies the name of the axis (parameter ``axis_name``) and the body of the function to be mapped as the ``call_jaxpr`` diff --git a/jax/_src/api.py b/jax/_src/api.py index da98d6e31..596bf9829 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2295,19 +2295,16 @@ def make_jaxpr(fun: Callable, >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) - { lambda ; a. - let b = cos a - c = sin b - in (c,) } + { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) - { lambda ; a. - let b = cos a - c = sin a - _ = sin b - d = cos b - e = mul 1.0 d - f = neg e - g = mul f c + { lambda ; a:f32[]. let + b:f32[] = cos a + c:f32[] = sin a + _:* = sin b + d:f32[] = cos b + e:f32[] = mul 1.0 d + f:f32[] = neg e + g:f32[] = mul f c in (g,) } """ _check_callable(fun) diff --git a/jax/_src/pprint_util.py b/jax/_src/pprint_util.py deleted file mode 100644 index cf6ee4d1b..000000000 --- a/jax/_src/pprint_util.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import operator as op - -from . import traceback_util -traceback_util.register_exclusion(__file__) - - -class PrettyPrint: - """Crude Hughes-inspired pretty printer.""" - - def __init__(self, lines): - self.lines = lines - - def indent(self, indent): - return PrettyPrint([(indent + orig_indent, s) - for orig_indent, s in self.lines]) - - def annotate(self, length, msg): - (i, s), *rest = self.lines - return PrettyPrint([(i, s.ljust(length) + f" [{msg}]")] + list(rest)) - - def __add__(self, rhs): - return PrettyPrint(self.lines + rhs.lines) - - def __rshift__(self, rhs): - if not rhs.lines: - return self - if not self.lines: - return rhs - - indent, s = self.lines[-1] - indented_block = rhs.indent(indent + len(s)) - common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1] - return PrettyPrint(self.lines[:-1] - + [(indent, common_line)] - + indented_block.lines[1:]) - - def __str__(self): - return '\n'.join(' ' * indent + s for indent, s in self.lines) - - -def pp(s): - return PrettyPrint([(0, line) for line in str(s).splitlines()]) - -def hcat(ps): - return functools.reduce(op.rshift, ps) - -def vcat(ps): - return sum(ps, pp('')) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py new file mode 100644 index 000000000..6d18e897b --- /dev/null +++ b/jax/_src/pretty_printer.py @@ -0,0 +1,390 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Wadler-Lindig pretty printer. +# +# References: +# Wadler, P., 1998. A prettier printer. Journal of Functional Programming, +# pp.223-244. +# +# Lindig, C. 2000. Strictly Pretty. +# https://lindig.github.io/papers/strictly-pretty-2000.pdf +# +# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for +# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated +# + +import abc +import enum +from functools import partial +from typing import List, NamedTuple, Optional, Sequence, Tuple, Union + +try: + import colorama # pytype: disable=import-error +except ImportError: + colorama = None + + +class Doc(abc.ABC): + __slots__ = () + + def format(self, width: int = 80, use_color: bool = False, + annotation_prefix=" # ") -> str: + return _format(self, width, use_color=use_color, + annotation_prefix=annotation_prefix) + + def __str__(self): + return self.format() + + def __add__(self, other: 'Doc') -> 'Doc': + return concat([self, other]) + +class _NilDoc(Doc): + def __repr__(self): return "nil" + +_nil = _NilDoc() + +class _TextDoc(Doc): + __slots__ = ("text", "annotation") + text: str + annotation: Optional[str] + + def __init__(self, text: str, annotation: Optional[str] = None): + assert isinstance(text, str), text + assert annotation is None or isinstance(annotation, str), annotation + self.text = text + self.annotation = annotation + + def __repr__(self): + if self.annotation is not None: + return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" + else: + return f"text(\"{self.text}\")" + +class _ConcatDoc(Doc): + __slots__ = ("children",) + children: List[Doc] + + def __init__(self, children: Sequence[Doc]): + self.children = list(children) + assert all(isinstance(doc, Doc) for doc in self.children), self.children + + def __repr__(self): return f"concat({self.children})" + +class _BreakDoc(Doc): + __slots__ = ("text",) + text: str + + def __init__(self, text: str): + assert isinstance(text, str), text + self.text = text + + def __repr__(self): return f"break({self.text})" + +class _GroupDoc(Doc): + __slots__ = ("child",) + child: Doc + + def __init__(self, child: Doc): + assert isinstance(child, Doc), child + self.child = child + + def __repr__(self): return f"group({self.child})" + +class _NestDoc(Doc): + __slots__ = ("n", "child",) + n: int + child: Doc + + def __init__(self, n: int, child: Doc): + assert isinstance(child, Doc), child + self.n = n + self.child = child + + def __repr__(self): return f"nest({self.n, self.child})" + + +Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", + "MAGENTA", "CYAN", "WHITE", "RESET"]) +Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"]) + +class _ColorDoc(Doc): + __slots__ = ("foreground", "background", "intensity", "child") + foreground: Optional[Color] + background: Optional[Color] + intensity: Optional[Intensity] + child: Doc + + def __init__(self, child: Doc, *, foreground: Optional[Color] = None, + background: Optional[Color] = None, + intensity: Optional[Intensity] = None): + assert isinstance(child, Doc), child + self.child = child + self.foreground = foreground + self.background = background + self.intensity = intensity + + +_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) + + +# In Lindig's paper fits() and format() are defined recursively. This is a +# non-recursive formulation using an explicit stack, necessary because Python +# doesn't have a tail recursion optimization. + +def _fits(doc: Doc, width: int, agenda: List[Tuple[int, _BreakMode, Doc]] + ) -> bool: + while width >= 0 and len(agenda) > 0: + i, m, doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + width -= len(doc.text) + elif isinstance(doc, _ConcatDoc): + agenda.extend((i, m, d) for d in reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + if m == _BreakMode.BREAK: + return True + width -= len(doc.text) + elif isinstance(doc, _NestDoc): + agenda.append((i + doc.n, m, doc.child)) + elif isinstance(doc, _GroupDoc): + agenda.append((i, _BreakMode.FLAT, doc.child)) + elif isinstance(doc, _ColorDoc): + agenda.append((i, m, doc.child)) + else: + raise ValueError("Invalid document ", doc) + + return width >= 0 + + +# Annotation layout: A flat group is sparse if there are no breaks between +# annotations. +def _sparse(doc: Doc) -> bool: + agenda = [doc] + num_annotations = 0 + seen_break = False + while len(agenda) > 0: + doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + if doc.annotation is not None: + if num_annotations >= 1 and seen_break: + return False + num_annotations += 1 + elif isinstance(doc, _ConcatDoc): + agenda.extend(reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + seen_break = True + elif isinstance(doc, _NestDoc): + agenda.append(doc.child) + elif isinstance(doc, _GroupDoc): + agenda.append(doc.child) + elif isinstance(doc, _ColorDoc): + agenda.append(doc.child) + else: + raise ValueError("Invalid document ", doc) + + return True + +class _ColorState(NamedTuple): + foreground: Color + background: Color + intensity: Intensity + +class _State(NamedTuple): + indent: int + mode: _BreakMode + doc: Doc + color: _ColorState + +class _Line(NamedTuple): + text: str + width: int + annotations: Union[Optional[str], List[str]] + + +def _update_color(use_color: bool, state: _ColorState, update: _ColorState + ) -> Tuple[_ColorState, str]: + if not use_color or colorama is None: + return update, "" + color_str = "" + if state.foreground != update.foreground: + color_str += getattr(colorama.Fore, str(update.foreground.name)) + if state.background != update.background: + color_str += getattr(colorama.Back, str(update.background.name)) + if state.intensity != update.intensity: + color_str += colorama.Style.NORMAL + color_str += getattr(colorama.Style, str(update.intensity.name)) + return update, color_str + + +def _align_annotations(lines): + # TODO: Hafiz also implements a local alignment mode, where groups of lines + # with annotations are aligned together. + maxlen = max(l.width for l in lines) + out = [] + for l in lines: + if len(l.annotations) == 0: + out.append(l._replace(annotations=None)) + elif len(l.annotations) == 1: + out.append(l._replace(text=l.text + " " * (maxlen - l.width), + annotations=l.annotations[0])) + else: + out.append(l._replace(text=l.text + " " * (maxlen - l.width), + annotations=l.annotations[0])) + for a in l.annotations[1:]: + out.append(_Line(text=" " * maxlen, width=l.width, annotations=a)) + return out + + + +def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: + lines = [] + default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) + annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) + color_state = default_colors + agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)] + k = 0 + line_text = "" + line_annotations = [] + while len(agenda) > 0: + i, m, doc, color = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + color_state, color_str = _update_color(use_color, color_state, color) + line_text += color_str + line_text += doc.text + if doc.annotation is not None: + line_annotations.append(doc.annotation) + k += len(doc.text) + elif isinstance(doc, _ConcatDoc): + agenda.extend(_State(i, m, d, color) + for d in reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + if m == _BreakMode.BREAK: + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + lines.append(_Line(line_text, k, line_annotations)) + line_text = " " * i + line_annotations = [] + k = i + else: + color_state, color_str = _update_color(use_color, color_state, color) + line_text += color_str + line_text += doc.text + k += len(doc.text) + elif isinstance(doc, _NestDoc): + agenda.append(_State(i + doc.n, m, doc.child, color)) + elif isinstance(doc, _GroupDoc): + # In Lindig's paper, _fits is passed the remainder of the document. + # I'm pretty sure that's a bug and we care only if the current group fits! + if (_sparse(doc) + and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): + agenda.append(_State(i, _BreakMode.FLAT, doc.child, color)) + else: + agenda.append(_State(i, _BreakMode.BREAK, doc.child, color)) + elif isinstance(doc, _ColorDoc): + color = _ColorState(doc.foreground or color.foreground, + doc.background or color.background, + doc.intensity or color.intensity) + agenda.append(_State(i, m, doc.child, color)) + else: + raise ValueError("Invalid document ", doc) + + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + lines.append(_Line(line_text, k, line_annotations)) + lines = _align_annotations(lines) + out = "\n".join( + l.text if l.annotations is None + else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines) + color_state, color_str = _update_color(use_color, color_state, + default_colors) + return out + color_str + + + + +# Public API. + +def nil() -> Doc: + """An empty document.""" + return _nil + +def text(s: str, annotation: Optional[str] = None) -> Doc: + """Literal text.""" + return _TextDoc(s, annotation) + +def concat(docs: Sequence[Doc]) -> Doc: + """Concatenation of documents.""" + docs = list(docs) + if len(docs) == 1: + return docs[0] + return _ConcatDoc(docs) + +def brk(text: str = " ") -> Doc: + """A break. + + Prints either as a newline or as `text`, depending on the enclosing group. + """ + return _BreakDoc(text) + +def group(doc: Doc) -> Doc: + """Layout alternative groups. + + Prints the group with its breaks as their text (typically spaces) if the + entire group would fit on the line when printed that way. Otherwise, breaks + inside the group as printed as newlines. + """ + return _GroupDoc(doc) + +def nest(n: int, doc: Doc) -> Doc: + """Increases the indentation level by `n`.""" + return _NestDoc(n, doc) + + +def color(doc: Doc, *, foreground: Optional[Color] = None, + background: Optional[Color] = None, + intensity: Optional[Intensity] = None): + """ANSI colors. + + Overrides the foreground/background/intensity of the text for the child doc. + Requires use_colors=True to be set when printing and the `colorama` package + to be installed; otherwise does nothing. + """ + return _ColorDoc(doc, foreground=foreground, background=background, + intensity=intensity) + + +dim = partial(color, intensity=Intensity.DIM) +bright = partial(color, intensity=Intensity.BRIGHT) + + +def join(sep: Doc, docs: Sequence[Doc]) -> Doc: + """Concatenates `docs`, separated by `sep`.""" + docs = list(docs) + if len(docs) == 0: + return nil() + xs = [docs[0]] + for doc in docs[1:]: + xs.append(sep) + xs.append(doc) + return concat(xs) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 323b26615..1f3f61476 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -30,7 +30,7 @@ from jax._src.api import jit, vmap from jax._src.lib import xla_bridge from jax._src.lib import xla_client from jax._src.lib import cuda_prng -from jax._src.pprint_util import pp, vcat +import jax._src.pretty_printer as pp from jax._src.util import prod @@ -63,9 +63,10 @@ class PRNGImpl(NamedTuple): fold_in: Callable def pprint(self): - return (pp(self.__class__.__name__) >> pp(':')) + vcat([ - pp(k) >> pp(' = ') >> pp(v) for k, v in self._asdict().items() - ]).indent(2) + return (pp.text(f"{self.__class__.__name__}:") + + pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [ + pp.text(f"{k} = {v}") for k, v in self._asdict().items() + ])))) # -- PRNG key arrays -- @@ -182,9 +183,11 @@ class PRNGKeyArray: def __repr__(self): arr_shape = self._shape - pp_keys = pp('shape = ') >> pp(arr_shape) - pp_impl = pp('impl = ') >> self.impl.pprint() - return str(pp('PRNGKeyArray:') + (pp_keys + pp_impl).indent(2)) + pp_keys = pp.text('shape = ') + pp.text(str(arr_shape)) + pp_impl = pp.text('impl = ') + self.impl.pprint() + return str(pp.group( + pp.text('PRNGKeyArray:') + + pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl))) def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray: diff --git a/jax/core.py b/jax/core.py index 03b3cad52..f0b43156c 100644 --- a/jax/core.py +++ b/jax/core.py @@ -40,7 +40,7 @@ from jax._src import source_info_util from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod, tuple_insert, tuple_delete, cache, as_hashable_function, HashableFunction) -from ._src.pprint_util import pp, vcat, PrettyPrint +import jax._src.pretty_printer as pp from ._src import traceback_util traceback_util.register_exclusion(__file__) @@ -77,6 +77,10 @@ class Jaxpr: return str(pp_jaxpr(self)) __repr__ = __str__ + def pretty_print(self, *, source_info=False, print_shapes=True, **kw): + doc = pp_jaxpr(self, source_info=source_info, print_shapes=print_shapes) + return doc.format(**kw) + def jaxprs_in_params(params) -> Iterator[Jaxpr]: for val in params.values(): @@ -128,6 +132,10 @@ class ClosedJaxpr: def __str__(self): return str(self.jaxpr) def __repr__(self): return repr(self.jaxpr) + def pretty_print(self, *, source_info=False, print_shapes=True, **kw): + return pp_jaxpr(self.jaxpr, source_info=source_info, + print_shapes=print_shapes).format(**kw) + @curry def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) @@ -574,13 +582,20 @@ class Tracer: else: return attr - def __repr__(self): - base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace)) - contents = [(name, pp(repr(attr))) for name, attr in self._contents()] + def _pretty_print(self): + base = pp.text(f'Traced<{self.aval}>with<{self._trace}>') + contents = [(name, attr._pretty_print() if isinstance(attr, Tracer) + else pp.text(repr(attr))) for name, attr in self._contents()] if contents: - base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload - for name, pp_payload in contents) - return str(base) + base = pp.group(pp.nest(2, pp.concat([ + base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [ + pp.text('{} = '.format(name)) + pp_payload + for name, pp_payload in contents]) + ]))) + return base + + def __repr__(self): + return self._pretty_print().format() def _contents(self): try: @@ -894,7 +909,7 @@ class AbstractValue: def update(self, **kwargs): raise NotImplementedError("must override") - def str_short(self): + def str_short(self, short_dtypes=False): raise NotImplementedError("must override") class Bot(AbstractValue): pass @@ -910,7 +925,7 @@ class AbstractUnit(AbstractValue): assert other is abstract_unit, other return self def _eq(self, self_traced, other): return get_aval(other) is self - def str_short(self): return '*' + def str_short(self, short_dtypes=False): return '*' abstract_unit = AbstractUnit() @@ -1005,6 +1020,11 @@ def concrete_or_error(force: Any, val: Any, context=""): convert_element_type_p = Primitive('convert_element_type') + +def _short_dtype_name(dtype): + return (dtype.name.replace('float', 'f').replace('uint', 'u') + .replace('int', 'i').replace('complex', 'c')) + class UnshapedArray(AbstractValue): __slots__ = ['dtype', 'weak_type'] array_abstraction_level = 2 @@ -1057,8 +1077,8 @@ class UnshapedArray(AbstractValue): else: raise TypeError(self, other) - def str_short(self) -> str: - return self.dtype.name + def str_short(self, short_dtypes=False) -> str: + return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name def strip_weak_type(self): """Returns a copy of the aval with weak_type=False.""" @@ -1126,13 +1146,14 @@ class ShapedArray(UnshapedArray): else: raise TypeError(self, other) - def str_short(self): + def str_short(self, short_dtypes=False): + dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name shapestr = ','.join(map(str, self.shape)) if self.named_shape: named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items()) - return f'{self.dtype.name}[{shapestr};{named_shapestr}]' + return f'{dt_str}[{shapestr};{named_shapestr}]' else: - return f'{self.dtype.name}[{shapestr}]' + return f'{dt_str}[{shapestr}]' def strip_named_shape(self): return self.update(named_shape={}) @@ -1193,8 +1214,9 @@ class ConcreteArray(ShapedArray): else: raise TypeError(self, other) - def str_short(self) -> str: - return f'{self.val}, dtype={self.dtype.name}' + def str_short(self, short_dtypes=False) -> str: + dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + return f'{self.val}, dtype={dt_str}' _bool = _nonzero = partialmethod(_forward_to_value, bool) _int = partialmethod(_forward_to_value, int) @@ -1216,7 +1238,7 @@ class AbstractToken(AbstractValue): return self else: assert False, f"Cannot join {self} with {other}" - def str_short(self): return 'Tok' + def str_short(self, short_dtypes=False): return 'Tok' def at_least_vspace(self): return self abstract_token: AbstractToken = AbstractToken() @@ -1954,7 +1976,7 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]): except JaxprTypeError as e: msg, = e.args src = source_info_util.summarize(eqn.source_info) - msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)), + msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn))), f"from source: {src}"]) raise JaxprTypeError(msg, eqn_idx) from None @@ -2028,81 +2050,104 @@ def check_map(prim, in_avals, params): # ------------------- Jaxpr printed representation ------------------- - -def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str: +def pp_vars(vs: Sequence[Any], *, print_shapes: bool = False) -> pp.Doc: if print_shapes: - return ' '.join(f'{v}:{v.aval.str_short()}' for v in vs) + return pp.nest(2, pp.group( + pp.join(pp.brk(), [ + pp.text(str(v)) + + pp.dim(pp.text(":" + v.aval.str_short(short_dtypes=True))) + for v in vs + ]) + )) else: - return ' '.join(map(str, vs)) + return pp.nest(2, pp.group( + pp.join(pp.brk(), [pp.text(str(v)) for v in vs]) + )) -def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint: +def pp_kv_pair(k:str, v: Any) -> pp.Doc: + if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v): + pp_v = pp_jaxprs(v) + elif isinstance(v, Jaxpr): + pp_v = pp_jaxpr(v) + elif isinstance(v, ClosedJaxpr): + pp_v = pp_jaxpr(v.jaxpr) + else: + pp_v = pp.text(str(v)) + return pp.text(f'{k}=') + pp_v + +def pp_kv_pairs(kv_pairs) -> pp.Doc: + if not kv_pairs: + return pp.nil() + return pp.group( + pp.nest(2, pp.concat([ + pp.text("["), pp.brk(""), + pp.join(pp.brk(), [pp_kv_pair(k, v) for k, v in kv_pairs]) + ])) + + pp.brk("") + pp.text("]") + ) + +def pp_eqn(eqn, *, print_shapes=True, source_info=False) -> pp.Doc: + lhs = pp_vars(eqn.outvars, print_shapes=print_shapes) + annotation = (source_info_util.summarize(eqn.source_info) + if source_info else None) + return pp.concat([ + lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name), + pp_kv_pairs(sorted(eqn.params.items())), + pp.text(" ") + pp_vars(eqn.invars) + ]) + + +def pp_eqns(eqns, *, print_shapes=True, source_info=False) -> pp.Doc: + return pp.join( + pp.brk("; "), + map(partial(pp_eqn, print_shapes=print_shapes, source_info=source_info), + eqns)) + +def pp_eqn_compact(primitive_name: str, params: Dict) -> pp.Doc: filtered_params = {k: v for k, v in params.items() if (k != 'branches' and not isinstance(v, (Jaxpr, ClosedJaxpr)))} - return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items())) + return pp.text(primitive_name) + pp_kv_pairs(sorted(filtered_params.items())) -def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint: - lhs = pp_vars(eqn.outvars, print_shapes) - pp_lhs = pp(f'{lhs} =') - pp_rhs = (pp(eqn.primitive.name) >> - pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >> - pp(pp_vars(eqn.invars, print_shapes))) - if len(lhs) <= 6 or print_shapes: - return pp_lhs >> pp(' ') >> pp_rhs - else: - return pp_lhs + pp_rhs.indent(2) - -def pp_eqns(eqns: Sequence[JaxprEqn], - source_info: bool = False) -> Sequence[PrettyPrint]: - pps = map(pp_eqn, eqns) - if source_info: - l = max((i + len(s) for x in pps for i, s in x.lines), default=None) - if l is not None: - return [p.annotate(l, source_info_util.summarize(e.source_info)) - for e, p in zip(eqns, pps)] - return pps - -def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint: - pps = pp_eqns(jaxpr.eqns, source_info=source_info) +def pp_jaxpr_skeleton(jaxpr, eqns_pp, *, print_shapes=True) -> pp.Doc: str_outvars = str(tuple(jaxpr.outvars)) - return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars), - pp_vars(jaxpr.invars))) + - ((pp('let ') >> vcat(pps)) - + pp('in {} }}'.format(str_outvars))).indent(2)) + return pp.group(pp.nest(2, pp.concat([ + pp.text("{ "), pp.bright(pp.text("lambda ")), + pp_vars(jaxpr.constvars, print_shapes=print_shapes), + pp.text("; "), pp_vars(jaxpr.invars, print_shapes=print_shapes), + pp.text(". "), pp.bright(pp.text("let")), + pp.nest(2, pp.brk() + eqns_pp), pp.brk(), + pp.bright(pp.text("in")), + pp.text(f" {str_outvars}") + ])) + pp.text(" }")) -def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, - source_info: bool = False) -> PrettyPrint: + +def pp_jaxpr(jaxpr, *, print_shapes=True, source_info=False) -> pp.Doc: + pps = pp_eqns(jaxpr.eqns, print_shapes=print_shapes, source_info=source_info) + return pp_jaxpr_skeleton(jaxpr, pps, print_shapes=print_shapes) + +def pp_jaxprs(jaxprs) -> pp.Doc: + jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] + return pp.group(pp.nest(2, pp.concat([ + pp.text('('), pp.brk(""), pp.join(pp.brk(), map(pp_jaxpr, jaxprs))])) + + pp.brk("") + pp.text(')') + ) + + +def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, print_shapes=True, + source_info: bool = False) -> pp.Doc: lo = max(lo, 0) hi = max(lo, min(hi, len(jaxpr.eqns))) eqns = jaxpr.eqns[lo:hi] pps = [] if len(eqns) == 0 and len(jaxpr.eqns) != 0: - pps.append(pp('...')) + pps.append(pp.text('...')) else: if lo != 0: - pps.append(pp('...')) - pps.extend(pp_eqns(eqns, source_info=source_info)) + pps.append(pp.text('...')) + pps.extend(map(partial(pp_eqn, print_shapes=print_shapes, + source_info=source_info), eqns)) if hi != len(jaxpr.eqns): - pps.append(pp('...')) - str_outvars = str(tuple(jaxpr.outvars)) - return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars), - pp_vars(jaxpr.invars))) + - ((pp('let ') >> vcat(pps)) - + pp('in {} }}'.format(str_outvars))).indent(2)) - -def pp_jaxprs(jaxprs) -> PrettyPrint: - jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] - return pp('( ') >> vcat(map(pp_jaxpr, jaxprs)) >> pp(' )') - -def pp_kv_pair(k, v): - if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v): - pp_v = pp_jaxprs(v) - else: - pp_v = pp(v) - return pp(f'{k}=') >> pp_v - -def pp_kv_pairs(kv_pairs): - if kv_pairs: - return pp('[ ') >> vcat([pp_kv_pair(k, v) for k, v in kv_pairs]) >> pp(' ]') - else: - return pp('') + pps.append(pp.text('...')) + return pp_jaxpr_skeleton(jaxpr, pp.join(pp.brk("; "), pps), + print_shapes=print_shapes) diff --git a/jax/experimental/djax.py b/jax/experimental/djax.py index b16be1bc3..744e0b41d 100644 --- a/jax/experimental/djax.py +++ b/jax/experimental/djax.py @@ -23,7 +23,7 @@ from jax._src import dtypes from jax.core import Var, Literal, Atom, Tracer from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list, tuple_delete) -from jax._src.pprint_util import pp, vcat, PrettyPrint +import jax._src.pretty_printer as pp map = safe_map zip = safe_zip @@ -177,7 +177,7 @@ class DJaxpr: def __repr__(self): return str(pp_djaxpr(self)) -def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint: +def pp_djaxpr(jaxpr: DJaxpr) -> pp.Doc: eqns = map(pp_eqn, jaxpr.eqns) in_dim_binders = pp_vars(jaxpr.in_dim_binders) in_binders = pp_vars(jaxpr.in_binders) @@ -185,21 +185,21 @@ def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint: outs = ', '.join(map(str, jaxpr.outs)) out_dim_types = pp_vars(jaxpr.out_dims) outs_type = ', '.join(v.aval.str_short() for v in jaxpr.outs) - return (pp(f'{{ lambda {in_dim_binders} ; {in_binders} .') - + (pp('let ') >> vcat(eqns) + - pp(f'in ( {out_dims} ; {outs} ) ' - f': ( {out_dim_types} ; {outs_type} ) }}')).indent(2)) + return (pp.text(f'{{ lambda {in_dim_binders} ; {in_binders} .') + + (pp.text('let ') + pp.nest(2, pp.brk() + pp.join(pp.brk(), eqns)) + + pp.text(f'in ( {out_dims} ; {outs} ) ' + f': ( {out_dim_types} ; {outs_type} ) }}'))) def pp_vars(vs: Sequence[Atom]) -> str: return ', '.join(f'{v}:{v.aval.str_short()}' for v in vs) -def pp_eqn(eqn: core.JaxprEqn) -> PrettyPrint: +def pp_eqn(eqn: core.JaxprEqn) -> pp.Doc: lhs = pp_vars(eqn.outvars) - pp_lhs = pp(f'{lhs} =') - pp_rhs = (pp(eqn.primitive.name) >> - core.pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >> - pp(' '.join(map(str, eqn.invars)))) - return pp_lhs >> pp(' ') >> pp_rhs + pp_lhs = pp.text(f'{lhs} =') + pp_rhs = (pp.text(eqn.primitive.name) + + core.pp_kv_pairs(sorted(eqn.params.items())) + pp.text(' ') + + pp.text(' '.join(map(str, eqn.invars)))) + return pp_lhs + pp.text(' ') + pp_rhs # Typechecking DJaxprs diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index c5dce6429..c2b46673b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -458,7 +458,7 @@ from jax import lax from jax.experimental import pjit from jax.interpreters import ad, xla, batching, masking, pxla from jax.interpreters import partial_eval as pe -from jax._src import pprint_util as ppu +from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import util from jax._src.lib import pytree @@ -747,21 +747,31 @@ def _print_tap_func( f"{k}: {v}" for k, v in sorted(kwargs.items()) ]) - def pp_val(arg) -> ppu.PrettyPrint: + def pp_val(arg) -> pp.Doc: if isinstance(arg, tuple): - return ( - ppu.pp("( ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" )")) + return pp.group(pp.concat([ + pp.text("( "), + pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), + pp.text(" )") + ])) elif isinstance(arg, list): - return ( - ppu.pp("[ ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" ]")) + return pp.group(pp.concat([ + pp.text("[ "), + pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), + pp.text(" ]") + ])) elif isinstance(arg, dict): - return (ppu.pp("{ ") >> ppu.vcat([ - ppu.pp(f"{k}=") >> pp_val(v) for k, v in sorted(arg.items()) - ]) >> ppu.pp(" }")) + return pp.group(pp.concat([ + pp.text("{ "), + pp.nest(2, pp.join(pp.brk(), [ + pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items()) + ])), + pp.text(" }") + ])) elif isinstance(arg, np.ndarray): - return ppu.pp(np.array2string(arg, threshold=threshold)) + return pp.text(np.array2string(arg, threshold=threshold)) else: - return ppu.pp(str(arg)) + return pp.text(str(arg)) with _print_tap_lock: if kv_pairs: diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index a4279c61f..134570540 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -36,7 +36,7 @@ from jax._src.abstract_arrays import (make_shaped_array, array_types) from ..core import (ConcreteArray, ShapedArray, AbstractToken, Literal, pp_eqn_compact, raise_to_shaped, abstract_token) from ..errors import UnexpectedTracerError -from jax._src.pprint_util import pp +import jax._src.pretty_printer as pp from .._src.util import (partialmethod, cache, prod, unzip2, extend_name_stack, wrap_name, safe_zip, safe_map, partition_list) @@ -112,11 +112,12 @@ def make_op_metadata(primitive: core.Primitive, name_stack: str = "", source_info: Optional[source_info_util.Traceback] = None ) -> xc.OpMetadata: - tracebacks[str(pp(name_stack) >> pp_eqn_compact(primitive.name, params))] = source_info + eqn_str = str(pp.text(name_stack) + pp_eqn_compact(primitive.name, params)) + tracebacks[eqn_str] = source_info frame = source_info_util.user_frame(source_info) if source_info else None return xc.OpMetadata( op_type=primitive.name, - op_name=str(pp(name_stack) >> pp_eqn_compact(primitive.name, params)), + op_name=eqn_str, source_file=_get_canonical_source_file(frame) if frame else None, source_line=frame.line_num if frame else None) diff --git a/mypy.ini b/mypy.ini index b03142e88..7e8e22d01 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,6 +4,8 @@ disable_error_code = attr-defined [mypy-absl.*] ignore_missing_imports = True +[mypy-colorama.*] +ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True [mypy-opt_einsum.*] diff --git a/tests/api_test.py b/tests/api_test.py index 1b94983dc..255bfac1f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3722,15 +3722,10 @@ class JaxprTest(jtu.JaxTestCase): def test_const(self): def fun(x): - return (x, 1., np.zeros(1)) + return (x, 1., np.zeros(1, dtype=jnp.float32)) - expected = """ - { lambda a ; b. - let - in (b, 1.0, a) } - """ - - jaxpr = api.make_jaxpr(fun)(0.) + expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }" + jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) def test_cond(self): @@ -3740,23 +3735,24 @@ class JaxprTest(jtu.JaxTestCase): lambda xt: xt + x, x + 2., lambda xf: xf - x) - expected = """ - { lambda ; a. - let b = ge a 0.0 - c = add a 1.0 - d = add a 2.0 - e = convert_element_type[ new_dtype=int32 - weak_type=False ] b - f = cond[ branches=( { lambda ; e_ a b c. - let d = sub c a - in (d,) } - { lambda ; a f_ b c. - let d = add b a - in (d,) } ) - linear=(False, False, False, False) ] e a a c d - in (f,) } - """ - jaxpr = api.make_jaxpr(f)(3.) + expected = """{ lambda ; a:f32[]. let + b:bool[] = ge a 0.0 + c:f32[] = add a 1.0 + d:f32[] = add a 2.0 + e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b + f:f32[] = cond[ + branches=( + { lambda ; e_:f32[] a:f32[] b:f32[] c:f32[]. let + d:f32[] = sub c a + in (d,) } + { lambda ; a:f32[] f_:f32[] b:f32[] c:f32[]. let + d:f32[] = add b a + in (d,) } + ) + linear=(False, False, False, False) + ] e a a c d + in (f,) }""" + jaxpr = api.make_jaxpr(f)(jnp.float32(3.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) def test_make_jaxpr_static_argnums(self): diff --git a/tests/core_test.py b/tests/core_test.py index e4c21af6d..64416f459 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -411,7 +411,7 @@ class JaxprTypeChecks(jtu.JaxTestCase): return jnp.sin(x) + jnp.cos(x) def new_jaxpr(): - return make_jaxpr(f)(1.).jaxpr + return make_jaxpr(f)(jnp.float32(1.)).jaxpr # jaxpr is: # @@ -424,19 +424,21 @@ class JaxprTypeChecks(jtu.JaxTestCase): # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b' jaxpr = new_jaxpr() - jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2) # int, not float! + # int, not float! + jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2)) self.assertRaisesRegex( core.JaxprTypeError, r"Variable '.' inconsistently typed as ShapedArray(.*), " - r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .", + r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() - jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3))) + jaxpr.eqns[0].outvars[0].aval = make_shaped_array( + np.ones((2, 3), dtype=jnp.float32)) self.assertRaisesRegex( core.JaxprTypeError, r"Variable '.' inconsistently typed as ShapedArray(.*), " - r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .", + r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 42172b90b..8ce954f04 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -275,8 +275,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - ( 6.00 - 9.00 )""", testing_stream.output) + ( 6.00 9.00 )""", testing_stream.output) def test_tap_with_dict_results(self): def func2(x): @@ -286,8 +285,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertEqual(3. * (2. + 3.), func2(3.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - { a=6.00 - b=9.00 }""", testing_stream.output) + { a=6.00 b=9.00 }""", testing_stream.output) def test_tap_with_result(self): def func2(x): @@ -298,8 +296,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertEqual(3. * 4., func2(3.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - ( 6.00 - 9.00 )""", testing_stream.output) + ( 6.00 9.00 )""", testing_stream.output) def test_tap_with_result_no_arg(self): def tap_func(arg, transforms): @@ -337,8 +334,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiDeviceOutputEqual(self, """ device: cpu:0 - ( 6.00 - 9.00 )""") + ( 6.00 9.00 )""") def test_tap_eval_exception(self): if not FLAGS.jax_host_callback_outfeed: @@ -375,8 +371,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): assertMultiLineStrippedEqual(self, """ ( ) what: second - ( 1.00 - [] )""", testing_stream.output) + ( 1.00 [] )""", testing_stream.output) def test_tap_jit_simple(self): jit_fun1 = jax.jit(lambda x: 3. * hcb.id_print( @@ -906,11 +901,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: ['jvp'] what: a * 2 - ( 10.00 - 0.20 ) + ( 10.00 0.20 ) transforms: ['jvp'] what: y * 3 - ( 30.00 - 0.60 )""", testing_stream.output) + ( 30.00 0.60 )""", testing_stream.output) def test_tap_grad_primal_unused(self): # The output of id_print is not needed for backwards pass @@ -925,23 +918,27 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() treedef = tree_util.tree_structure(arg) + print(jaxpr) assertMultiLineStrippedEqual(self, f""" - {{ lambda ; a. - let b = mul a 3.00 - c = outside_call[ arg_treedef={treedef} - callback=... - identity=True - transforms=( ) ] b - _ = mul c 2.00 - d = mul 1.00 2.00 - _ = broadcast_in_dim[ broadcast_dimensions=( ) - shape=( ) ] 0.00 - e = outside_call[ arg_treedef={treedef} - callback=... - identity=True - transforms=(('jvp',), ('transpose',)) ] d - f = mul e 3.00 - in (f,) }}""", jaxpr) + {{ lambda ; a:f32[]. let + b:f32[] = mul a 3.00 + c:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + identity=True + transforms=() + ] b + _:* = mul c 2.00 + d:f32[] = mul 1.00 2.00 + _:* = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00 + e:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + identity=True + transforms=(('jvp',), ('transpose',)) + ] d + f:f32[] = mul e 3.00 + in (f,) }}""", jaxpr) assertMultiLineStrippedEqual(self, "", testing_stream.output) testing_stream.reset() @@ -1016,11 +1013,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: pair - ( 10.00 - 15.00 ) + ( 10.00 15.00 ) transforms: ['jvp', 'transpose'] what: pair - ( 0.00 - 0.00 )""", testing_stream.output) + ( 0.00 0.00 )""", testing_stream.output) def test_tap_jvp_float0(self): def f(x, yint): @@ -1042,11 +1037,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: pair - ( 5.00 - 2 ) + ( 5.00 2 ) transforms: ['jvp', 'transpose'] what: pair - ( 2.00 - False )""", testing_stream.output) + ( 2.00 False )""", testing_stream.output) def test_tap_grad_float0_result(self): # https://github.com/google/jax/issues/7340 @@ -1068,11 +1061,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertEqual(dtypes.float0, g[1].dtype) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] - [11 12 13] ) + ( [0.70 0.80] [11 12 13] ) transforms: ['jvp', 'transpose'] - ( [0.00 0.00] - [False False False] )""", testing_stream.output) + ( [0.00 0.00] [False False False] )""", testing_stream.output) def test_tap_higher_order_grad_float0_result(self): # https://github.com/google/jax/issues/7340 @@ -1100,8 +1091,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): res = f_jax(x) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] - [11 12 13] )""", testing_stream.output) + ( [0.70 0.80] [11 12 13] )""", testing_stream.output) testing_stream.reset() # 1st order @@ -1109,11 +1099,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): res_vjp1 = f_jax_vjp1(*args_vjp1) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] - [11 12 13] ) + ( [0.70 0.80] [11 12 13] ) transforms: ['jvp', 'transpose'] - ( [0.00 0.00] - [False False False] )""", testing_stream.output) + ( [0.00 0.00] [False False False] )""", testing_stream.output) testing_stream.reset() # 2nd order @@ -1149,8 +1137,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: [('batch', {'batch_dims': (None, 0)})] - ( 3.00 - [4.00 5.00] )""", testing_stream.output) + ( 3.00 [4.00 5.00] )""", testing_stream.output) def test_tap_vmap_vmap(self): # A 2D tensor with x[i, j] = i + j using 2 vmap @@ -1249,8 +1236,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() expected = """ what: x,x^2 - ( 3. - 9. )""" + ( 3. 9. )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1258,8 +1244,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() expected = """ transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [0. 1. 2.] - [0. 1. 4.] )""" + ( [0. 1. 2.] [0. 1. 4.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1267,10 +1252,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() expected = """ transforms: ['jvp'] what: x,x^2 - ( ( 3. - 9. ) - ( 0.1 - 0.6 ) )""" + ( ( 3. 9. ) ( 0.1 0.6 ) )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1278,11 +1260,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() expected = """ what: x,x^2 - ( 3. - 9. ) + ( 3. 9. ) transforms: ['jvp', 'transpose'] what: x,x^2 - ( 0. - 3. )""" + ( 0. 3. )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1290,11 +1270,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() expected = """ transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] - [4. 9.] ) + ( [2. 3.] [4. 9.] ) transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2 - ( 0. - [2. 3.] )""" + ( 0. [2. 3.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) @@ -1320,11 +1298,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): assertMultiDeviceOutputEqual( self, """ device: cpu:0 what: x,x^2 - ( 3 - 9 ) + ( 3 9 ) device: cpu:1 what: x,x^2 - ( 4 - 16 )""") + ( 4 16 )""") def test_tap_pmap_vmap(self): # A matrix M[ij] = i * 10 + j @@ -1440,13 +1416,11 @@ class HostCallbackTapTest(jtu.JaxTestCase): assertMultiDeviceOutputEqual(self, """ device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 ( [[ 0.00 2.00 4.00] - [20.00 22.00 24.00]] - [[0.20 0.20 0.20] + [20.00 22.00 24.00]] [[0.20 0.20 0.20] [0.20 0.20 0.20]] ) device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 ( [[200.00 202.00 204.00] - [220.00 222.00 224.00]] - [[0.20 0.20 0.20] + [220.00 222.00 224.00]] [[0.20 0.20 0.20] [0.20 0.20 0.20]] )""") def test_tap_vmap_pmap(self): @@ -1693,10 +1667,8 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() self.assertMultiLineStrippedEqual(""" transforms: [('mask', {'logical_shapes': 5})] what: x - ( ( [0. 1. 2. 3. 4.] - [0. 2. 4. 6. 8.] ) - ( ( 3 ) - ( 3 ) ) )""", testing_stream.output) + ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""", + testing_stream.output) testing_stream.reset() # With VMAP @@ -1713,8 +1685,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): [5. 6. 7. 8. 9.]] [[ 0. 2. 4. 6. 8.] [10. 12. 14. 16. 18.]] ) - ( ( [3. 4.] ) - ( [3. 4.] ) ) )""", testing_stream.output) + ( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output) testing_stream.reset() # With JVP @@ -1724,14 +1695,9 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() self.assertMultiLineStrippedEqual(""" transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x - ( ( ( [0. 1. 2. 3. 4.] - [0. 2. 4. 6. 8.] ) - ( ( 3 ) - ( 3 ) ) ) - ( ( [0. 0.1 0.2 0.3 0.4] - [0. 0.2 0.4 0.6 0.8] ) - ( ( False ) - ( False ) ) ) )""", testing_stream.output) + ( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) ) + ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""", + testing_stream.output) testing_stream.reset() # Now with JIT @@ -1739,10 +1705,8 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() self.assertMultiLineStrippedEqual(""" transforms: [('mask', {'logical_shapes': 5})] what: x - ( ( [0. 1. 2. 3. 4.] - [0. 2. 4. 6. 8.] ) - ( ( 3 ) - ( 3 ) ) )""", testing_stream.output) + ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""", + testing_stream.output) def test_tap_callback_delay(self): hcb.callback_extra = lambda dev: time.sleep(1) diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 30b93bbaf..1fa1840de 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -80,7 +80,7 @@ class MetadataTest(jtu.JaxTestCase): return jax.lax.cond(True, x, true_fun, x, false_fun) hlo = jax.xla_computation(f)(1.).get_hlo_module().to_string() self.assertRegex(hlo, 'op_type="cond"') - self.assertRegex(hlo, 'op_name=".*cond\\[ linear=\\(False, False\\) \\]"') + self.assertRegex(hlo, 'op_name=".*cond\\[linear=\\(False, False\\)\\]"') self.assertRegex(hlo, 'op_type="cos"') self.assertRegex(hlo, 'op_name=".*cond/branch_0_fun/cos"') self.assertRegex(hlo, 'op_type="sin"')