mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
This commit is contained in:
parent
001d20d250
commit
5fa4613e99
@ -1,3 +1,4 @@
|
||||
colorama>=0.4.4
|
||||
flatbuffers==2.0
|
||||
pillow>=8.3.1
|
||||
pytest-benchmark
|
||||
|
244
docs/jaxpr.rst
244
docs/jaxpr.rst
@ -92,11 +92,11 @@ For example, here is the jaxpr produced for the function ``func1`` below
|
||||
... return jnp.sum(temp)
|
||||
...
|
||||
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
{ lambda ; a:f32[8] b:f32[8]. let
|
||||
c:f32[8] = sin b
|
||||
d:f32[8] = mul c 3.0
|
||||
e:f32[8] = add a d
|
||||
f:f32[] = reduce_sum[axes=(0,)] e
|
||||
in (f,) }
|
||||
|
||||
Here there are no constvars, ``a`` and ``b`` are the input variables
|
||||
@ -129,11 +129,11 @@ jaxpr as before
|
||||
... return func2(inner, first, second)
|
||||
...
|
||||
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
{ lambda ; a:f32[8] b:f32[8]. let
|
||||
c:f32[8] = sin b
|
||||
d:f32[8] = mul c 3.0
|
||||
e:f32[8] = add a d
|
||||
f:f32[] = reduce_sum[axes=(0,)] e
|
||||
in (f,) }
|
||||
|
||||
|
||||
@ -155,11 +155,11 @@ before (with two input vars, one for each element of the input tuple)
|
||||
... return jnp.sum(temp)
|
||||
...
|
||||
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
{ lambda ; a:f32[8] b:f32[8]. let
|
||||
c:f32[8] = sin b
|
||||
d:f32[8] = mul c 3.0
|
||||
e:f32[8] = add a d
|
||||
f:f32[] = reduce_sum[axes=(0,)] e
|
||||
in (f,) }
|
||||
|
||||
|
||||
@ -209,20 +209,17 @@ For example:
|
||||
... arg)
|
||||
...
|
||||
>>> print(make_jaxpr(one_of_three)(1, 5.))
|
||||
{ lambda ; a b.
|
||||
let c = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] a
|
||||
d = clamp 0 c 2
|
||||
e = cond[ branches=( { lambda ; a.
|
||||
let b = add a 1.0
|
||||
in (b,) }
|
||||
{ lambda ; a.
|
||||
let b = sub a 2.0
|
||||
in (b,) }
|
||||
{ lambda ; a.
|
||||
let b = add a 3.0
|
||||
in (b,) } )
|
||||
linear=(False,) ] d b
|
||||
{ lambda ; a:i32[] b:f32[]. let
|
||||
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
|
||||
d:i32[] = clamp 0 c 2
|
||||
e:f32[] = cond[
|
||||
branches=(
|
||||
{ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
|
||||
{ lambda ; a:f32[]. let b:f32[] = sub a 2.0 in (b,) }
|
||||
{ lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) }
|
||||
)
|
||||
linear=(False,)
|
||||
] d b
|
||||
in (e,) }
|
||||
|
||||
The cond primitive has a number of parameters:
|
||||
@ -250,20 +247,18 @@ Another example, using :py:func:`lax.cond`:
|
||||
... arg)
|
||||
...
|
||||
>>> print(make_jaxpr(func7)(5.))
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] b
|
||||
d = cond[ branches=( { lambda ; a.
|
||||
let b = sub a 3.0
|
||||
in (b,) }
|
||||
{ lambda ; a.
|
||||
let b = add a 3.0
|
||||
in (b,) } )
|
||||
linear=(False,) ] c a
|
||||
{ lambda ; a:f32[]. let
|
||||
b:bool[] = ge a 0.0
|
||||
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
|
||||
d:f32[] = cond[
|
||||
branches=(
|
||||
{ lambda ; a:f32[]. let b:f32[] = sub a 3.0 in (b,) }
|
||||
{ lambda ; a:f32[]. let b:f32[] = add a 3.0 in (b,) }
|
||||
)
|
||||
linear=(False,)
|
||||
] c a
|
||||
in (d,) }
|
||||
|
||||
|
||||
In this case, the boolean predicate is converted to an integer index
|
||||
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
|
||||
true branch functionals, in that order. Again, each functional takes
|
||||
@ -281,24 +276,23 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
|
||||
... arg2)
|
||||
...
|
||||
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
|
||||
{ lambda a ; b c d.
|
||||
let e = ge b 0.0
|
||||
f = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] e
|
||||
g = cond[ branches=( { lambda ; a b c.
|
||||
let d = convert_element_type[ new_dtype=float32
|
||||
weak_type=True ] a
|
||||
e = add d c
|
||||
in (e,) }
|
||||
{ lambda ; f_ a b.
|
||||
let
|
||||
in (a,) } )
|
||||
linear=(False, False, False) ] f a c d
|
||||
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
|
||||
e:bool[] = ge b 0.0
|
||||
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
|
||||
g:f32[1] = cond[
|
||||
branches=(
|
||||
{ lambda ; a:i32[1] b:f32[1] c:f32[]. let
|
||||
d:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] a
|
||||
e:f32[1] = add d c
|
||||
in (e,) }
|
||||
{ lambda ; f_:i32[1] a:f32[1] b:f32[]. let in (a,) }
|
||||
)
|
||||
linear=(False, False, False)
|
||||
] f a c d
|
||||
in (g,) }
|
||||
|
||||
|
||||
|
||||
|
||||
While
|
||||
^^^^^
|
||||
|
||||
@ -324,21 +318,22 @@ For example, here is an example fori loop
|
||||
... arg + ones)
|
||||
...
|
||||
>>> print(make_jaxpr(func10)(np.ones(16), 5))
|
||||
{ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(16,) ] 1.0
|
||||
d = add a c
|
||||
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
|
||||
let f = add c 1
|
||||
g = mul a 3.0
|
||||
h = add e g
|
||||
i = add h b
|
||||
in (f, d, i) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; a b c.
|
||||
let d = lt a b
|
||||
in (d,) }
|
||||
cond_nconsts=0 ] c a 0 b d
|
||||
{ lambda ; a:f32[16] b:i32[]. let
|
||||
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
|
||||
d:f32[16] = add a c
|
||||
_:* _:* e:f32[16] = while[
|
||||
body_jaxpr={ lambda ; a:f32[16] b:f32[16] c:i32[] d:i32[] e:f32[16]. let
|
||||
f:i32[] = add c 1
|
||||
g:f32[16] = mul a 3.0
|
||||
h:f32[16] = add e g
|
||||
i:f32[16] = add h b
|
||||
in (f, d, i) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; a:i32[] b:i32[] c:f32[16]. let
|
||||
d:bool[] = lt a b
|
||||
in (d,) }
|
||||
cond_nconsts=0
|
||||
] c a 0 b d
|
||||
in (e,) }
|
||||
|
||||
The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
|
||||
@ -372,22 +367,22 @@ For the example consider the function ``func11`` below
|
||||
... return lax.scan(body, 0., (arr, ones))
|
||||
...
|
||||
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
|
||||
{ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(16,) ] 1.0
|
||||
d e = scan[ jaxpr={ lambda ; a b c d.
|
||||
let e = mul c d
|
||||
f = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] b
|
||||
g = add f e
|
||||
h = add g a
|
||||
in (h, b) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
num_consts=1
|
||||
reverse=False
|
||||
unroll=1 ] b 0.0 a c
|
||||
{ lambda ; a:f32[16] b:f32[]. let
|
||||
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
|
||||
d:f32[] e:f32[16] = scan[
|
||||
jaxpr={ lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let
|
||||
e:f32[] = mul c d
|
||||
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
|
||||
g:f32[] = add f e
|
||||
h:f32[] = add g a
|
||||
in (h, b) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
num_consts=1
|
||||
reverse=False
|
||||
unroll=1
|
||||
] b 0.0 a c
|
||||
in (d, e) }
|
||||
|
||||
The ``linear`` parameter describes for each of the input variables whether they
|
||||
@ -416,24 +411,23 @@ which the computation should run. For example
|
||||
... return arg + inner(arg - 2.)
|
||||
...
|
||||
>>> print(make_jaxpr(func12)(1.))
|
||||
{ lambda ; a.
|
||||
let b = sub a 2.0
|
||||
c = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
d = mul a c
|
||||
e = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] b
|
||||
f = add e d
|
||||
in (f,) }
|
||||
device=None
|
||||
donated_invars=(False, False)
|
||||
inline=False
|
||||
name=inner ] a b
|
||||
d = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
e = add d c
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = sub a 2.0
|
||||
c:f32[1] = xla_call[
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; a:f32[] b:f32[]. let
|
||||
c:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
|
||||
d:f32[1] = mul a c
|
||||
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
|
||||
f:f32[1] = add e d
|
||||
in (f,) }
|
||||
device=None
|
||||
donated_invars=(False, False)
|
||||
inline=False
|
||||
name=inner
|
||||
] a b
|
||||
d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
|
||||
e:f32[1] = add d c
|
||||
in (e,) }
|
||||
|
||||
|
||||
@ -452,27 +446,27 @@ captured using the ``xla_pmap`` primitive. Consider this example
|
||||
... return pmap(inner, axis_name='rows')(arr)
|
||||
...
|
||||
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
|
||||
{ lambda ; a b.
|
||||
let c = xla_pmap[ axis_name=rows
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = add b a
|
||||
d = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
e = add c d
|
||||
f = psum[ axes=('rows',)
|
||||
axis_index_groups=None ] b
|
||||
g = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
donated_invars=(False, False)
|
||||
global_arg_shapes=(None,)
|
||||
global_axis_size=None
|
||||
in_axes=(None, 0)
|
||||
name=inner
|
||||
out_axes=(0,) ] b a
|
||||
in (c,) }
|
||||
{ lambda ; a:f32[1,3] b:f32[]. let
|
||||
c:f32[1,3] = xla_pmap[
|
||||
axis_name=rows
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; a:f32[] b:f32[3]. let
|
||||
c:f32[3] = add b a
|
||||
d:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
|
||||
e:f32[3] = add c d
|
||||
f:f32[3] = psum[axes=('rows',) axis_index_groups=None] b
|
||||
g:f32[3] = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
donated_invars=(False, False)
|
||||
global_arg_shapes=(None,)
|
||||
global_axis_size=None
|
||||
in_axes=(None, 0)
|
||||
name=inner
|
||||
out_axes=(0,)
|
||||
] b a
|
||||
in (c,) }
|
||||
|
||||
The ``xla_pmap`` primitive specifies the name of the axis (parameter
|
||||
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
|
||||
|
@ -2295,19 +2295,16 @@ def make_jaxpr(fun: Callable,
|
||||
>>> print(f(3.0))
|
||||
-0.83602
|
||||
>>> jax.make_jaxpr(f)(3.0)
|
||||
{ lambda ; a.
|
||||
let b = cos a
|
||||
c = sin b
|
||||
in (c,) }
|
||||
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
|
||||
>>> jax.make_jaxpr(jax.grad(f))(3.0)
|
||||
{ lambda ; a.
|
||||
let b = cos a
|
||||
c = sin a
|
||||
_ = sin b
|
||||
d = cos b
|
||||
e = mul 1.0 d
|
||||
f = neg e
|
||||
g = mul f c
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = cos a
|
||||
c:f32[] = sin a
|
||||
_:* = sin b
|
||||
d:f32[] = cos b
|
||||
e:f32[] = mul 1.0 d
|
||||
f:f32[] = neg e
|
||||
g:f32[] = mul f c
|
||||
in (g,) }
|
||||
"""
|
||||
_check_callable(fun)
|
||||
|
@ -1,63 +0,0 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import operator as op
|
||||
|
||||
from . import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
class PrettyPrint:
|
||||
"""Crude Hughes-inspired pretty printer."""
|
||||
|
||||
def __init__(self, lines):
|
||||
self.lines = lines
|
||||
|
||||
def indent(self, indent):
|
||||
return PrettyPrint([(indent + orig_indent, s)
|
||||
for orig_indent, s in self.lines])
|
||||
|
||||
def annotate(self, length, msg):
|
||||
(i, s), *rest = self.lines
|
||||
return PrettyPrint([(i, s.ljust(length) + f" [{msg}]")] + list(rest))
|
||||
|
||||
def __add__(self, rhs):
|
||||
return PrettyPrint(self.lines + rhs.lines)
|
||||
|
||||
def __rshift__(self, rhs):
|
||||
if not rhs.lines:
|
||||
return self
|
||||
if not self.lines:
|
||||
return rhs
|
||||
|
||||
indent, s = self.lines[-1]
|
||||
indented_block = rhs.indent(indent + len(s))
|
||||
common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
|
||||
return PrettyPrint(self.lines[:-1]
|
||||
+ [(indent, common_line)]
|
||||
+ indented_block.lines[1:])
|
||||
|
||||
def __str__(self):
|
||||
return '\n'.join(' ' * indent + s for indent, s in self.lines)
|
||||
|
||||
|
||||
def pp(s):
|
||||
return PrettyPrint([(0, line) for line in str(s).splitlines()])
|
||||
|
||||
def hcat(ps):
|
||||
return functools.reduce(op.rshift, ps)
|
||||
|
||||
def vcat(ps):
|
||||
return sum(ps, pp(''))
|
390
jax/_src/pretty_printer.py
Normal file
390
jax/_src/pretty_printer.py
Normal file
@ -0,0 +1,390 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Wadler-Lindig pretty printer.
|
||||
#
|
||||
# References:
|
||||
# Wadler, P., 1998. A prettier printer. Journal of Functional Programming,
|
||||
# pp.223-244.
|
||||
#
|
||||
# Lindig, C. 2000. Strictly Pretty.
|
||||
# https://lindig.github.io/papers/strictly-pretty-2000.pdf
|
||||
#
|
||||
# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for
|
||||
# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated
|
||||
#
|
||||
|
||||
import abc
|
||||
import enum
|
||||
from functools import partial
|
||||
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
try:
|
||||
import colorama # pytype: disable=import-error
|
||||
except ImportError:
|
||||
colorama = None
|
||||
|
||||
|
||||
class Doc(abc.ABC):
|
||||
__slots__ = ()
|
||||
|
||||
def format(self, width: int = 80, use_color: bool = False,
|
||||
annotation_prefix=" # ") -> str:
|
||||
return _format(self, width, use_color=use_color,
|
||||
annotation_prefix=annotation_prefix)
|
||||
|
||||
def __str__(self):
|
||||
return self.format()
|
||||
|
||||
def __add__(self, other: 'Doc') -> 'Doc':
|
||||
return concat([self, other])
|
||||
|
||||
class _NilDoc(Doc):
|
||||
def __repr__(self): return "nil"
|
||||
|
||||
_nil = _NilDoc()
|
||||
|
||||
class _TextDoc(Doc):
|
||||
__slots__ = ("text", "annotation")
|
||||
text: str
|
||||
annotation: Optional[str]
|
||||
|
||||
def __init__(self, text: str, annotation: Optional[str] = None):
|
||||
assert isinstance(text, str), text
|
||||
assert annotation is None or isinstance(annotation, str), annotation
|
||||
self.text = text
|
||||
self.annotation = annotation
|
||||
|
||||
def __repr__(self):
|
||||
if self.annotation is not None:
|
||||
return f"text(\"{self.text}\", annotation=\"{self.annotation}\")"
|
||||
else:
|
||||
return f"text(\"{self.text}\")"
|
||||
|
||||
class _ConcatDoc(Doc):
|
||||
__slots__ = ("children",)
|
||||
children: List[Doc]
|
||||
|
||||
def __init__(self, children: Sequence[Doc]):
|
||||
self.children = list(children)
|
||||
assert all(isinstance(doc, Doc) for doc in self.children), self.children
|
||||
|
||||
def __repr__(self): return f"concat({self.children})"
|
||||
|
||||
class _BreakDoc(Doc):
|
||||
__slots__ = ("text",)
|
||||
text: str
|
||||
|
||||
def __init__(self, text: str):
|
||||
assert isinstance(text, str), text
|
||||
self.text = text
|
||||
|
||||
def __repr__(self): return f"break({self.text})"
|
||||
|
||||
class _GroupDoc(Doc):
|
||||
__slots__ = ("child",)
|
||||
child: Doc
|
||||
|
||||
def __init__(self, child: Doc):
|
||||
assert isinstance(child, Doc), child
|
||||
self.child = child
|
||||
|
||||
def __repr__(self): return f"group({self.child})"
|
||||
|
||||
class _NestDoc(Doc):
|
||||
__slots__ = ("n", "child",)
|
||||
n: int
|
||||
child: Doc
|
||||
|
||||
def __init__(self, n: int, child: Doc):
|
||||
assert isinstance(child, Doc), child
|
||||
self.n = n
|
||||
self.child = child
|
||||
|
||||
def __repr__(self): return f"nest({self.n, self.child})"
|
||||
|
||||
|
||||
Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE",
|
||||
"MAGENTA", "CYAN", "WHITE", "RESET"])
|
||||
Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"])
|
||||
|
||||
class _ColorDoc(Doc):
|
||||
__slots__ = ("foreground", "background", "intensity", "child")
|
||||
foreground: Optional[Color]
|
||||
background: Optional[Color]
|
||||
intensity: Optional[Intensity]
|
||||
child: Doc
|
||||
|
||||
def __init__(self, child: Doc, *, foreground: Optional[Color] = None,
|
||||
background: Optional[Color] = None,
|
||||
intensity: Optional[Intensity] = None):
|
||||
assert isinstance(child, Doc), child
|
||||
self.child = child
|
||||
self.foreground = foreground
|
||||
self.background = background
|
||||
self.intensity = intensity
|
||||
|
||||
|
||||
_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
|
||||
|
||||
|
||||
# In Lindig's paper fits() and format() are defined recursively. This is a
|
||||
# non-recursive formulation using an explicit stack, necessary because Python
|
||||
# doesn't have a tail recursion optimization.
|
||||
|
||||
def _fits(doc: Doc, width: int, agenda: List[Tuple[int, _BreakMode, Doc]]
|
||||
) -> bool:
|
||||
while width >= 0 and len(agenda) > 0:
|
||||
i, m, doc = agenda.pop()
|
||||
if isinstance(doc, _NilDoc):
|
||||
pass
|
||||
elif isinstance(doc, _TextDoc):
|
||||
width -= len(doc.text)
|
||||
elif isinstance(doc, _ConcatDoc):
|
||||
agenda.extend((i, m, d) for d in reversed(doc.children))
|
||||
elif isinstance(doc, _BreakDoc):
|
||||
if m == _BreakMode.BREAK:
|
||||
return True
|
||||
width -= len(doc.text)
|
||||
elif isinstance(doc, _NestDoc):
|
||||
agenda.append((i + doc.n, m, doc.child))
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
agenda.append((i, _BreakMode.FLAT, doc.child))
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
agenda.append((i, m, doc.child))
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
|
||||
return width >= 0
|
||||
|
||||
|
||||
# Annotation layout: A flat group is sparse if there are no breaks between
|
||||
# annotations.
|
||||
def _sparse(doc: Doc) -> bool:
|
||||
agenda = [doc]
|
||||
num_annotations = 0
|
||||
seen_break = False
|
||||
while len(agenda) > 0:
|
||||
doc = agenda.pop()
|
||||
if isinstance(doc, _NilDoc):
|
||||
pass
|
||||
elif isinstance(doc, _TextDoc):
|
||||
if doc.annotation is not None:
|
||||
if num_annotations >= 1 and seen_break:
|
||||
return False
|
||||
num_annotations += 1
|
||||
elif isinstance(doc, _ConcatDoc):
|
||||
agenda.extend(reversed(doc.children))
|
||||
elif isinstance(doc, _BreakDoc):
|
||||
seen_break = True
|
||||
elif isinstance(doc, _NestDoc):
|
||||
agenda.append(doc.child)
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
agenda.append(doc.child)
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
agenda.append(doc.child)
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
|
||||
return True
|
||||
|
||||
class _ColorState(NamedTuple):
|
||||
foreground: Color
|
||||
background: Color
|
||||
intensity: Intensity
|
||||
|
||||
class _State(NamedTuple):
|
||||
indent: int
|
||||
mode: _BreakMode
|
||||
doc: Doc
|
||||
color: _ColorState
|
||||
|
||||
class _Line(NamedTuple):
|
||||
text: str
|
||||
width: int
|
||||
annotations: Union[Optional[str], List[str]]
|
||||
|
||||
|
||||
def _update_color(use_color: bool, state: _ColorState, update: _ColorState
|
||||
) -> Tuple[_ColorState, str]:
|
||||
if not use_color or colorama is None:
|
||||
return update, ""
|
||||
color_str = ""
|
||||
if state.foreground != update.foreground:
|
||||
color_str += getattr(colorama.Fore, str(update.foreground.name))
|
||||
if state.background != update.background:
|
||||
color_str += getattr(colorama.Back, str(update.background.name))
|
||||
if state.intensity != update.intensity:
|
||||
color_str += colorama.Style.NORMAL
|
||||
color_str += getattr(colorama.Style, str(update.intensity.name))
|
||||
return update, color_str
|
||||
|
||||
|
||||
def _align_annotations(lines):
|
||||
# TODO: Hafiz also implements a local alignment mode, where groups of lines
|
||||
# with annotations are aligned together.
|
||||
maxlen = max(l.width for l in lines)
|
||||
out = []
|
||||
for l in lines:
|
||||
if len(l.annotations) == 0:
|
||||
out.append(l._replace(annotations=None))
|
||||
elif len(l.annotations) == 1:
|
||||
out.append(l._replace(text=l.text + " " * (maxlen - l.width),
|
||||
annotations=l.annotations[0]))
|
||||
else:
|
||||
out.append(l._replace(text=l.text + " " * (maxlen - l.width),
|
||||
annotations=l.annotations[0]))
|
||||
for a in l.annotations[1:]:
|
||||
out.append(_Line(text=" " * maxlen, width=l.width, annotations=a))
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
|
||||
lines = []
|
||||
default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
|
||||
annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)
|
||||
color_state = default_colors
|
||||
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)]
|
||||
k = 0
|
||||
line_text = ""
|
||||
line_annotations = []
|
||||
while len(agenda) > 0:
|
||||
i, m, doc, color = agenda.pop()
|
||||
if isinstance(doc, _NilDoc):
|
||||
pass
|
||||
elif isinstance(doc, _TextDoc):
|
||||
color_state, color_str = _update_color(use_color, color_state, color)
|
||||
line_text += color_str
|
||||
line_text += doc.text
|
||||
if doc.annotation is not None:
|
||||
line_annotations.append(doc.annotation)
|
||||
k += len(doc.text)
|
||||
elif isinstance(doc, _ConcatDoc):
|
||||
agenda.extend(_State(i, m, d, color)
|
||||
for d in reversed(doc.children))
|
||||
elif isinstance(doc, _BreakDoc):
|
||||
if m == _BreakMode.BREAK:
|
||||
if len(line_annotations) > 0:
|
||||
color_state, color_str = _update_color(use_color, color_state,
|
||||
annotation_colors)
|
||||
line_text += color_str
|
||||
lines.append(_Line(line_text, k, line_annotations))
|
||||
line_text = " " * i
|
||||
line_annotations = []
|
||||
k = i
|
||||
else:
|
||||
color_state, color_str = _update_color(use_color, color_state, color)
|
||||
line_text += color_str
|
||||
line_text += doc.text
|
||||
k += len(doc.text)
|
||||
elif isinstance(doc, _NestDoc):
|
||||
agenda.append(_State(i + doc.n, m, doc.child, color))
|
||||
elif isinstance(doc, _GroupDoc):
|
||||
# In Lindig's paper, _fits is passed the remainder of the document.
|
||||
# I'm pretty sure that's a bug and we care only if the current group fits!
|
||||
if (_sparse(doc)
|
||||
and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])):
|
||||
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color))
|
||||
else:
|
||||
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color))
|
||||
elif isinstance(doc, _ColorDoc):
|
||||
color = _ColorState(doc.foreground or color.foreground,
|
||||
doc.background or color.background,
|
||||
doc.intensity or color.intensity)
|
||||
agenda.append(_State(i, m, doc.child, color))
|
||||
else:
|
||||
raise ValueError("Invalid document ", doc)
|
||||
|
||||
if len(line_annotations) > 0:
|
||||
color_state, color_str = _update_color(use_color, color_state,
|
||||
annotation_colors)
|
||||
line_text += color_str
|
||||
lines.append(_Line(line_text, k, line_annotations))
|
||||
lines = _align_annotations(lines)
|
||||
out = "\n".join(
|
||||
l.text if l.annotations is None
|
||||
else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines)
|
||||
color_state, color_str = _update_color(use_color, color_state,
|
||||
default_colors)
|
||||
return out + color_str
|
||||
|
||||
|
||||
|
||||
|
||||
# Public API.
|
||||
|
||||
def nil() -> Doc:
|
||||
"""An empty document."""
|
||||
return _nil
|
||||
|
||||
def text(s: str, annotation: Optional[str] = None) -> Doc:
|
||||
"""Literal text."""
|
||||
return _TextDoc(s, annotation)
|
||||
|
||||
def concat(docs: Sequence[Doc]) -> Doc:
|
||||
"""Concatenation of documents."""
|
||||
docs = list(docs)
|
||||
if len(docs) == 1:
|
||||
return docs[0]
|
||||
return _ConcatDoc(docs)
|
||||
|
||||
def brk(text: str = " ") -> Doc:
|
||||
"""A break.
|
||||
|
||||
Prints either as a newline or as `text`, depending on the enclosing group.
|
||||
"""
|
||||
return _BreakDoc(text)
|
||||
|
||||
def group(doc: Doc) -> Doc:
|
||||
"""Layout alternative groups.
|
||||
|
||||
Prints the group with its breaks as their text (typically spaces) if the
|
||||
entire group would fit on the line when printed that way. Otherwise, breaks
|
||||
inside the group as printed as newlines.
|
||||
"""
|
||||
return _GroupDoc(doc)
|
||||
|
||||
def nest(n: int, doc: Doc) -> Doc:
|
||||
"""Increases the indentation level by `n`."""
|
||||
return _NestDoc(n, doc)
|
||||
|
||||
|
||||
def color(doc: Doc, *, foreground: Optional[Color] = None,
|
||||
background: Optional[Color] = None,
|
||||
intensity: Optional[Intensity] = None):
|
||||
"""ANSI colors.
|
||||
|
||||
Overrides the foreground/background/intensity of the text for the child doc.
|
||||
Requires use_colors=True to be set when printing and the `colorama` package
|
||||
to be installed; otherwise does nothing.
|
||||
"""
|
||||
return _ColorDoc(doc, foreground=foreground, background=background,
|
||||
intensity=intensity)
|
||||
|
||||
|
||||
dim = partial(color, intensity=Intensity.DIM)
|
||||
bright = partial(color, intensity=Intensity.BRIGHT)
|
||||
|
||||
|
||||
def join(sep: Doc, docs: Sequence[Doc]) -> Doc:
|
||||
"""Concatenates `docs`, separated by `sep`."""
|
||||
docs = list(docs)
|
||||
if len(docs) == 0:
|
||||
return nil()
|
||||
xs = [docs[0]]
|
||||
for doc in docs[1:]:
|
||||
xs.append(sep)
|
||||
xs.append(doc)
|
||||
return concat(xs)
|
@ -30,7 +30,7 @@ from jax._src.api import jit, vmap
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import cuda_prng
|
||||
from jax._src.pprint_util import pp, vcat
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import prod
|
||||
|
||||
|
||||
@ -63,9 +63,10 @@ class PRNGImpl(NamedTuple):
|
||||
fold_in: Callable
|
||||
|
||||
def pprint(self):
|
||||
return (pp(self.__class__.__name__) >> pp(':')) + vcat([
|
||||
pp(k) >> pp(' = ') >> pp(v) for k, v in self._asdict().items()
|
||||
]).indent(2)
|
||||
return (pp.text(f"{self.__class__.__name__}:") +
|
||||
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
|
||||
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
|
||||
]))))
|
||||
|
||||
|
||||
# -- PRNG key arrays --
|
||||
@ -182,9 +183,11 @@ class PRNGKeyArray:
|
||||
|
||||
def __repr__(self):
|
||||
arr_shape = self._shape
|
||||
pp_keys = pp('shape = ') >> pp(arr_shape)
|
||||
pp_impl = pp('impl = ') >> self.impl.pprint()
|
||||
return str(pp('PRNGKeyArray:') + (pp_keys + pp_impl).indent(2))
|
||||
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
|
||||
pp_impl = pp.text('impl = ') + self.impl.pprint()
|
||||
return str(pp.group(
|
||||
pp.text('PRNGKeyArray:') +
|
||||
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
|
||||
|
||||
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
|
||||
|
203
jax/core.py
203
jax/core.py
@ -40,7 +40,7 @@ from jax._src import source_info_util
|
||||
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
|
||||
tuple_insert, tuple_delete, cache, as_hashable_function,
|
||||
HashableFunction)
|
||||
from ._src.pprint_util import pp, vcat, PrettyPrint
|
||||
import jax._src.pretty_printer as pp
|
||||
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -77,6 +77,10 @@ class Jaxpr:
|
||||
return str(pp_jaxpr(self))
|
||||
__repr__ = __str__
|
||||
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
|
||||
doc = pp_jaxpr(self, source_info=source_info, print_shapes=print_shapes)
|
||||
return doc.format(**kw)
|
||||
|
||||
|
||||
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
|
||||
for val in params.values():
|
||||
@ -128,6 +132,10 @@ class ClosedJaxpr:
|
||||
def __str__(self): return str(self.jaxpr)
|
||||
def __repr__(self): return repr(self.jaxpr)
|
||||
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
|
||||
return pp_jaxpr(self.jaxpr, source_info=source_info,
|
||||
print_shapes=print_shapes).format(**kw)
|
||||
|
||||
@curry
|
||||
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
|
||||
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
|
||||
@ -574,13 +582,20 @@ class Tracer:
|
||||
else:
|
||||
return attr
|
||||
|
||||
def __repr__(self):
|
||||
base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
|
||||
contents = [(name, pp(repr(attr))) for name, attr in self._contents()]
|
||||
def _pretty_print(self):
|
||||
base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
|
||||
contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
|
||||
else pp.text(repr(attr))) for name, attr in self._contents()]
|
||||
if contents:
|
||||
base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
|
||||
for name, pp_payload in contents)
|
||||
return str(base)
|
||||
base = pp.group(pp.nest(2, pp.concat([
|
||||
base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
|
||||
pp.text('{} = '.format(name)) + pp_payload
|
||||
for name, pp_payload in contents])
|
||||
])))
|
||||
return base
|
||||
|
||||
def __repr__(self):
|
||||
return self._pretty_print().format()
|
||||
|
||||
def _contents(self):
|
||||
try:
|
||||
@ -894,7 +909,7 @@ class AbstractValue:
|
||||
def update(self, **kwargs):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def str_short(self):
|
||||
def str_short(self, short_dtypes=False):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
class Bot(AbstractValue): pass
|
||||
@ -910,7 +925,7 @@ class AbstractUnit(AbstractValue):
|
||||
assert other is abstract_unit, other
|
||||
return self
|
||||
def _eq(self, self_traced, other): return get_aval(other) is self
|
||||
def str_short(self): return '*'
|
||||
def str_short(self, short_dtypes=False): return '*'
|
||||
|
||||
abstract_unit = AbstractUnit()
|
||||
|
||||
@ -1005,6 +1020,11 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
|
||||
convert_element_type_p = Primitive('convert_element_type')
|
||||
|
||||
|
||||
def _short_dtype_name(dtype):
|
||||
return (dtype.name.replace('float', 'f').replace('uint', 'u')
|
||||
.replace('int', 'i').replace('complex', 'c'))
|
||||
|
||||
class UnshapedArray(AbstractValue):
|
||||
__slots__ = ['dtype', 'weak_type']
|
||||
array_abstraction_level = 2
|
||||
@ -1057,8 +1077,8 @@ class UnshapedArray(AbstractValue):
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self) -> str:
|
||||
return self.dtype.name
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
|
||||
def strip_weak_type(self):
|
||||
"""Returns a copy of the aval with weak_type=False."""
|
||||
@ -1126,13 +1146,14 @@ class ShapedArray(UnshapedArray):
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self):
|
||||
def str_short(self, short_dtypes=False):
|
||||
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
if self.named_shape:
|
||||
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
|
||||
return f'{self.dtype.name}[{shapestr};{named_shapestr}]'
|
||||
return f'{dt_str}[{shapestr};{named_shapestr}]'
|
||||
else:
|
||||
return f'{self.dtype.name}[{shapestr}]'
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
|
||||
def strip_named_shape(self):
|
||||
return self.update(named_shape={})
|
||||
@ -1193,8 +1214,9 @@ class ConcreteArray(ShapedArray):
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self) -> str:
|
||||
return f'{self.val}, dtype={self.dtype.name}'
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
return f'{self.val}, dtype={dt_str}'
|
||||
|
||||
_bool = _nonzero = partialmethod(_forward_to_value, bool)
|
||||
_int = partialmethod(_forward_to_value, int)
|
||||
@ -1216,7 +1238,7 @@ class AbstractToken(AbstractValue):
|
||||
return self
|
||||
else:
|
||||
assert False, f"Cannot join {self} with {other}"
|
||||
def str_short(self): return 'Tok'
|
||||
def str_short(self, short_dtypes=False): return 'Tok'
|
||||
def at_least_vspace(self): return self
|
||||
|
||||
abstract_token: AbstractToken = AbstractToken()
|
||||
@ -1954,7 +1976,7 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
except JaxprTypeError as e:
|
||||
msg, = e.args
|
||||
src = source_info_util.summarize(eqn.source_info)
|
||||
msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)),
|
||||
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn))),
|
||||
f"from source: {src}"])
|
||||
raise JaxprTypeError(msg, eqn_idx) from None
|
||||
|
||||
@ -2028,81 +2050,104 @@ def check_map(prim, in_avals, params):
|
||||
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
||||
def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str:
|
||||
def pp_vars(vs: Sequence[Any], *, print_shapes: bool = False) -> pp.Doc:
|
||||
if print_shapes:
|
||||
return ' '.join(f'{v}:{v.aval.str_short()}' for v in vs)
|
||||
return pp.nest(2, pp.group(
|
||||
pp.join(pp.brk(), [
|
||||
pp.text(str(v)) +
|
||||
pp.dim(pp.text(":" + v.aval.str_short(short_dtypes=True)))
|
||||
for v in vs
|
||||
])
|
||||
))
|
||||
else:
|
||||
return ' '.join(map(str, vs))
|
||||
return pp.nest(2, pp.group(
|
||||
pp.join(pp.brk(), [pp.text(str(v)) for v in vs])
|
||||
))
|
||||
|
||||
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
|
||||
def pp_kv_pair(k:str, v: Any) -> pp.Doc:
|
||||
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
|
||||
pp_v = pp_jaxprs(v)
|
||||
elif isinstance(v, Jaxpr):
|
||||
pp_v = pp_jaxpr(v)
|
||||
elif isinstance(v, ClosedJaxpr):
|
||||
pp_v = pp_jaxpr(v.jaxpr)
|
||||
else:
|
||||
pp_v = pp.text(str(v))
|
||||
return pp.text(f'{k}=') + pp_v
|
||||
|
||||
def pp_kv_pairs(kv_pairs) -> pp.Doc:
|
||||
if not kv_pairs:
|
||||
return pp.nil()
|
||||
return pp.group(
|
||||
pp.nest(2, pp.concat([
|
||||
pp.text("["), pp.brk(""),
|
||||
pp.join(pp.brk(), [pp_kv_pair(k, v) for k, v in kv_pairs])
|
||||
]))
|
||||
+ pp.brk("") + pp.text("]")
|
||||
)
|
||||
|
||||
def pp_eqn(eqn, *, print_shapes=True, source_info=False) -> pp.Doc:
|
||||
lhs = pp_vars(eqn.outvars, print_shapes=print_shapes)
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if source_info else None)
|
||||
return pp.concat([
|
||||
lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
|
||||
pp_kv_pairs(sorted(eqn.params.items())),
|
||||
pp.text(" ") + pp_vars(eqn.invars)
|
||||
])
|
||||
|
||||
|
||||
def pp_eqns(eqns, *, print_shapes=True, source_info=False) -> pp.Doc:
|
||||
return pp.join(
|
||||
pp.brk("; "),
|
||||
map(partial(pp_eqn, print_shapes=print_shapes, source_info=source_info),
|
||||
eqns))
|
||||
|
||||
def pp_eqn_compact(primitive_name: str, params: Dict) -> pp.Doc:
|
||||
filtered_params = {k: v for k, v in params.items()
|
||||
if (k != 'branches' and
|
||||
not isinstance(v, (Jaxpr, ClosedJaxpr)))}
|
||||
return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))
|
||||
return pp.text(primitive_name) + pp_kv_pairs(sorted(filtered_params.items()))
|
||||
|
||||
def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
|
||||
lhs = pp_vars(eqn.outvars, print_shapes)
|
||||
pp_lhs = pp(f'{lhs} =')
|
||||
pp_rhs = (pp(eqn.primitive.name) >>
|
||||
pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
|
||||
pp(pp_vars(eqn.invars, print_shapes)))
|
||||
if len(lhs) <= 6 or print_shapes:
|
||||
return pp_lhs >> pp(' ') >> pp_rhs
|
||||
else:
|
||||
return pp_lhs + pp_rhs.indent(2)
|
||||
|
||||
def pp_eqns(eqns: Sequence[JaxprEqn],
|
||||
source_info: bool = False) -> Sequence[PrettyPrint]:
|
||||
pps = map(pp_eqn, eqns)
|
||||
if source_info:
|
||||
l = max((i + len(s) for x in pps for i, s in x.lines), default=None)
|
||||
if l is not None:
|
||||
return [p.annotate(l, source_info_util.summarize(e.source_info))
|
||||
for e, p in zip(eqns, pps)]
|
||||
return pps
|
||||
|
||||
def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint:
|
||||
pps = pp_eqns(jaxpr.eqns, source_info=source_info)
|
||||
def pp_jaxpr_skeleton(jaxpr, eqns_pp, *, print_shapes=True) -> pp.Doc:
|
||||
str_outvars = str(tuple(jaxpr.outvars))
|
||||
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
|
||||
pp_vars(jaxpr.invars))) +
|
||||
((pp('let ') >> vcat(pps))
|
||||
+ pp('in {} }}'.format(str_outvars))).indent(2))
|
||||
return pp.group(pp.nest(2, pp.concat([
|
||||
pp.text("{ "), pp.bright(pp.text("lambda ")),
|
||||
pp_vars(jaxpr.constvars, print_shapes=print_shapes),
|
||||
pp.text("; "), pp_vars(jaxpr.invars, print_shapes=print_shapes),
|
||||
pp.text(". "), pp.bright(pp.text("let")),
|
||||
pp.nest(2, pp.brk() + eqns_pp), pp.brk(),
|
||||
pp.bright(pp.text("in")),
|
||||
pp.text(f" {str_outvars}")
|
||||
])) + pp.text(" }"))
|
||||
|
||||
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int,
|
||||
source_info: bool = False) -> PrettyPrint:
|
||||
|
||||
def pp_jaxpr(jaxpr, *, print_shapes=True, source_info=False) -> pp.Doc:
|
||||
pps = pp_eqns(jaxpr.eqns, print_shapes=print_shapes, source_info=source_info)
|
||||
return pp_jaxpr_skeleton(jaxpr, pps, print_shapes=print_shapes)
|
||||
|
||||
def pp_jaxprs(jaxprs) -> pp.Doc:
|
||||
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
|
||||
return pp.group(pp.nest(2, pp.concat([
|
||||
pp.text('('), pp.brk(""), pp.join(pp.brk(), map(pp_jaxpr, jaxprs))]))
|
||||
+ pp.brk("") + pp.text(')')
|
||||
)
|
||||
|
||||
|
||||
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, print_shapes=True,
|
||||
source_info: bool = False) -> pp.Doc:
|
||||
lo = max(lo, 0)
|
||||
hi = max(lo, min(hi, len(jaxpr.eqns)))
|
||||
eqns = jaxpr.eqns[lo:hi]
|
||||
pps = []
|
||||
if len(eqns) == 0 and len(jaxpr.eqns) != 0:
|
||||
pps.append(pp('...'))
|
||||
pps.append(pp.text('...'))
|
||||
else:
|
||||
if lo != 0:
|
||||
pps.append(pp('...'))
|
||||
pps.extend(pp_eqns(eqns, source_info=source_info))
|
||||
pps.append(pp.text('...'))
|
||||
pps.extend(map(partial(pp_eqn, print_shapes=print_shapes,
|
||||
source_info=source_info), eqns))
|
||||
if hi != len(jaxpr.eqns):
|
||||
pps.append(pp('...'))
|
||||
str_outvars = str(tuple(jaxpr.outvars))
|
||||
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
|
||||
pp_vars(jaxpr.invars))) +
|
||||
((pp('let ') >> vcat(pps))
|
||||
+ pp('in {} }}'.format(str_outvars))).indent(2))
|
||||
|
||||
def pp_jaxprs(jaxprs) -> PrettyPrint:
|
||||
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
|
||||
return pp('( ') >> vcat(map(pp_jaxpr, jaxprs)) >> pp(' )')
|
||||
|
||||
def pp_kv_pair(k, v):
|
||||
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
|
||||
pp_v = pp_jaxprs(v)
|
||||
else:
|
||||
pp_v = pp(v)
|
||||
return pp(f'{k}=') >> pp_v
|
||||
|
||||
def pp_kv_pairs(kv_pairs):
|
||||
if kv_pairs:
|
||||
return pp('[ ') >> vcat([pp_kv_pair(k, v) for k, v in kv_pairs]) >> pp(' ]')
|
||||
else:
|
||||
return pp('')
|
||||
pps.append(pp.text('...'))
|
||||
return pp_jaxpr_skeleton(jaxpr, pp.join(pp.brk("; "), pps),
|
||||
print_shapes=print_shapes)
|
||||
|
@ -23,7 +23,7 @@ from jax._src import dtypes
|
||||
from jax.core import Var, Literal, Atom, Tracer
|
||||
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
|
||||
tuple_delete)
|
||||
from jax._src.pprint_util import pp, vcat, PrettyPrint
|
||||
import jax._src.pretty_printer as pp
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -177,7 +177,7 @@ class DJaxpr:
|
||||
def __repr__(self):
|
||||
return str(pp_djaxpr(self))
|
||||
|
||||
def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint:
|
||||
def pp_djaxpr(jaxpr: DJaxpr) -> pp.Doc:
|
||||
eqns = map(pp_eqn, jaxpr.eqns)
|
||||
in_dim_binders = pp_vars(jaxpr.in_dim_binders)
|
||||
in_binders = pp_vars(jaxpr.in_binders)
|
||||
@ -185,21 +185,21 @@ def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint:
|
||||
outs = ', '.join(map(str, jaxpr.outs))
|
||||
out_dim_types = pp_vars(jaxpr.out_dims)
|
||||
outs_type = ', '.join(v.aval.str_short() for v in jaxpr.outs)
|
||||
return (pp(f'{{ lambda {in_dim_binders} ; {in_binders} .')
|
||||
+ (pp('let ') >> vcat(eqns) +
|
||||
pp(f'in ( {out_dims} ; {outs} ) '
|
||||
f': ( {out_dim_types} ; {outs_type} ) }}')).indent(2))
|
||||
return (pp.text(f'{{ lambda {in_dim_binders} ; {in_binders} .')
|
||||
+ (pp.text('let ') + pp.nest(2, pp.brk() + pp.join(pp.brk(), eqns)) +
|
||||
pp.text(f'in ( {out_dims} ; {outs} ) '
|
||||
f': ( {out_dim_types} ; {outs_type} ) }}')))
|
||||
|
||||
def pp_vars(vs: Sequence[Atom]) -> str:
|
||||
return ', '.join(f'{v}:{v.aval.str_short()}' for v in vs)
|
||||
|
||||
def pp_eqn(eqn: core.JaxprEqn) -> PrettyPrint:
|
||||
def pp_eqn(eqn: core.JaxprEqn) -> pp.Doc:
|
||||
lhs = pp_vars(eqn.outvars)
|
||||
pp_lhs = pp(f'{lhs} =')
|
||||
pp_rhs = (pp(eqn.primitive.name) >>
|
||||
core.pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
|
||||
pp(' '.join(map(str, eqn.invars))))
|
||||
return pp_lhs >> pp(' ') >> pp_rhs
|
||||
pp_lhs = pp.text(f'{lhs} =')
|
||||
pp_rhs = (pp.text(eqn.primitive.name) +
|
||||
core.pp_kv_pairs(sorted(eqn.params.items())) + pp.text(' ') +
|
||||
pp.text(' '.join(map(str, eqn.invars))))
|
||||
return pp_lhs + pp.text(' ') + pp_rhs
|
||||
|
||||
# Typechecking DJaxprs
|
||||
|
||||
|
@ -458,7 +458,7 @@ from jax import lax
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad, xla, batching, masking, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import pprint_util as ppu
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lib import pytree
|
||||
@ -747,21 +747,31 @@ def _print_tap_func(
|
||||
f"{k}: {v}" for k, v in sorted(kwargs.items())
|
||||
])
|
||||
|
||||
def pp_val(arg) -> ppu.PrettyPrint:
|
||||
def pp_val(arg) -> pp.Doc:
|
||||
if isinstance(arg, tuple):
|
||||
return (
|
||||
ppu.pp("( ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" )"))
|
||||
return pp.group(pp.concat([
|
||||
pp.text("( "),
|
||||
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
|
||||
pp.text(" )")
|
||||
]))
|
||||
elif isinstance(arg, list):
|
||||
return (
|
||||
ppu.pp("[ ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" ]"))
|
||||
return pp.group(pp.concat([
|
||||
pp.text("[ "),
|
||||
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
|
||||
pp.text(" ]")
|
||||
]))
|
||||
elif isinstance(arg, dict):
|
||||
return (ppu.pp("{ ") >> ppu.vcat([
|
||||
ppu.pp(f"{k}=") >> pp_val(v) for k, v in sorted(arg.items())
|
||||
]) >> ppu.pp(" }"))
|
||||
return pp.group(pp.concat([
|
||||
pp.text("{ "),
|
||||
pp.nest(2, pp.join(pp.brk(), [
|
||||
pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items())
|
||||
])),
|
||||
pp.text(" }")
|
||||
]))
|
||||
elif isinstance(arg, np.ndarray):
|
||||
return ppu.pp(np.array2string(arg, threshold=threshold))
|
||||
return pp.text(np.array2string(arg, threshold=threshold))
|
||||
else:
|
||||
return ppu.pp(str(arg))
|
||||
return pp.text(str(arg))
|
||||
|
||||
with _print_tap_lock:
|
||||
if kv_pairs:
|
||||
|
@ -36,7 +36,7 @@ from jax._src.abstract_arrays import (make_shaped_array, array_types)
|
||||
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
|
||||
Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
|
||||
from ..errors import UnexpectedTracerError
|
||||
from jax._src.pprint_util import pp
|
||||
import jax._src.pretty_printer as pp
|
||||
from .._src.util import (partialmethod, cache, prod, unzip2,
|
||||
extend_name_stack, wrap_name, safe_zip, safe_map,
|
||||
partition_list)
|
||||
@ -112,11 +112,12 @@ def make_op_metadata(primitive: core.Primitive,
|
||||
name_stack: str = "",
|
||||
source_info: Optional[source_info_util.Traceback] = None
|
||||
) -> xc.OpMetadata:
|
||||
tracebacks[str(pp(name_stack) >> pp_eqn_compact(primitive.name, params))] = source_info
|
||||
eqn_str = str(pp.text(name_stack) + pp_eqn_compact(primitive.name, params))
|
||||
tracebacks[eqn_str] = source_info
|
||||
frame = source_info_util.user_frame(source_info) if source_info else None
|
||||
return xc.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
op_name=str(pp(name_stack) >> pp_eqn_compact(primitive.name, params)),
|
||||
op_name=eqn_str,
|
||||
source_file=_get_canonical_source_file(frame) if frame else None,
|
||||
source_line=frame.line_num if frame else None)
|
||||
|
||||
|
2
mypy.ini
2
mypy.ini
@ -4,6 +4,8 @@ disable_error_code = attr-defined
|
||||
|
||||
[mypy-absl.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-colorama.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-numpy.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-opt_einsum.*]
|
||||
|
@ -3722,15 +3722,10 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_const(self):
|
||||
def fun(x):
|
||||
return (x, 1., np.zeros(1))
|
||||
return (x, 1., np.zeros(1, dtype=jnp.float32))
|
||||
|
||||
expected = """
|
||||
{ lambda a ; b.
|
||||
let
|
||||
in (b, 1.0, a) }
|
||||
"""
|
||||
|
||||
jaxpr = api.make_jaxpr(fun)(0.)
|
||||
expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }"
|
||||
jaxpr = api.make_jaxpr(fun)(jnp.float32(0.))
|
||||
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
|
||||
|
||||
def test_cond(self):
|
||||
@ -3740,23 +3735,24 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
lambda xt: xt + x,
|
||||
x + 2.,
|
||||
lambda xf: xf - x)
|
||||
expected = """
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = add a 1.0
|
||||
d = add a 2.0
|
||||
e = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] b
|
||||
f = cond[ branches=( { lambda ; e_ a b c.
|
||||
let d = sub c a
|
||||
in (d,) }
|
||||
{ lambda ; a f_ b c.
|
||||
let d = add b a
|
||||
in (d,) } )
|
||||
linear=(False, False, False, False) ] e a a c d
|
||||
in (f,) }
|
||||
"""
|
||||
jaxpr = api.make_jaxpr(f)(3.)
|
||||
expected = """{ lambda ; a:f32[]. let
|
||||
b:bool[] = ge a 0.0
|
||||
c:f32[] = add a 1.0
|
||||
d:f32[] = add a 2.0
|
||||
e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
|
||||
f:f32[] = cond[
|
||||
branches=(
|
||||
{ lambda ; e_:f32[] a:f32[] b:f32[] c:f32[]. let
|
||||
d:f32[] = sub c a
|
||||
in (d,) }
|
||||
{ lambda ; a:f32[] f_:f32[] b:f32[] c:f32[]. let
|
||||
d:f32[] = add b a
|
||||
in (d,) }
|
||||
)
|
||||
linear=(False, False, False, False)
|
||||
] e a a c d
|
||||
in (f,) }"""
|
||||
jaxpr = api.make_jaxpr(f)(jnp.float32(3.))
|
||||
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
|
||||
|
||||
def test_make_jaxpr_static_argnums(self):
|
||||
|
@ -411,7 +411,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
return jnp.sin(x) + jnp.cos(x)
|
||||
|
||||
def new_jaxpr():
|
||||
return make_jaxpr(f)(1.).jaxpr
|
||||
return make_jaxpr(f)(jnp.float32(1.)).jaxpr
|
||||
|
||||
# jaxpr is:
|
||||
#
|
||||
@ -424,19 +424,21 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
# NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'
|
||||
|
||||
jaxpr = new_jaxpr()
|
||||
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2) # int, not float!
|
||||
# int, not float!
|
||||
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Variable '.' inconsistently typed as ShapedArray(.*), "
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .",
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
jaxpr = new_jaxpr()
|
||||
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3)))
|
||||
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(
|
||||
np.ones((2, 3), dtype=jnp.float32))
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Variable '.' inconsistently typed as ShapedArray(.*), "
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n . = sin .",
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
def test_jaxpr_dropvar_from_jit_call(self):
|
||||
|
@ -275,8 +275,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( 6.00
|
||||
9.00 )""", testing_stream.output)
|
||||
( 6.00 9.00 )""", testing_stream.output)
|
||||
|
||||
def test_tap_with_dict_results(self):
|
||||
def func2(x):
|
||||
@ -286,8 +285,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ a=6.00
|
||||
b=9.00 }""", testing_stream.output)
|
||||
{ a=6.00 b=9.00 }""", testing_stream.output)
|
||||
|
||||
def test_tap_with_result(self):
|
||||
def func2(x):
|
||||
@ -298,8 +296,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(3. * 4., func2(3.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( 6.00
|
||||
9.00 )""", testing_stream.output)
|
||||
( 6.00 9.00 )""", testing_stream.output)
|
||||
|
||||
def test_tap_with_result_no_arg(self):
|
||||
def tap_func(arg, transforms):
|
||||
@ -337,8 +334,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
assertMultiDeviceOutputEqual(self, """
|
||||
device: cpu:0
|
||||
( 6.00
|
||||
9.00 )""")
|
||||
( 6.00 9.00 )""")
|
||||
|
||||
def test_tap_eval_exception(self):
|
||||
if not FLAGS.jax_host_callback_outfeed:
|
||||
@ -375,8 +371,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( )
|
||||
what: second
|
||||
( 1.00
|
||||
[] )""", testing_stream.output)
|
||||
( 1.00 [] )""", testing_stream.output)
|
||||
|
||||
def test_tap_jit_simple(self):
|
||||
jit_fun1 = jax.jit(lambda x: 3. * hcb.id_print(
|
||||
@ -906,11 +901,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ['jvp'] what: a * 2
|
||||
( 10.00
|
||||
0.20 )
|
||||
( 10.00 0.20 )
|
||||
transforms: ['jvp'] what: y * 3
|
||||
( 30.00
|
||||
0.60 )""", testing_stream.output)
|
||||
( 30.00 0.60 )""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_primal_unused(self):
|
||||
# The output of id_print is not needed for backwards pass
|
||||
@ -925,23 +918,27 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
|
||||
treedef = tree_util.tree_structure(arg)
|
||||
print(jaxpr)
|
||||
assertMultiLineStrippedEqual(self, f"""
|
||||
{{ lambda ; a.
|
||||
let b = mul a 3.00
|
||||
c = outside_call[ arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=( ) ] b
|
||||
_ = mul c 2.00
|
||||
d = mul 1.00 2.00
|
||||
_ = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=( ) ] 0.00
|
||||
e = outside_call[ arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=(('jvp',), ('transpose',)) ] d
|
||||
f = mul e 3.00
|
||||
in (f,) }}""", jaxpr)
|
||||
{{ lambda ; a:f32[]. let
|
||||
b:f32[] = mul a 3.00
|
||||
c:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=()
|
||||
] b
|
||||
_:* = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
_:* = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
|
||||
e:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=(('jvp',), ('transpose',))
|
||||
] d
|
||||
f:f32[] = mul e 3.00
|
||||
in (f,) }}""", jaxpr)
|
||||
assertMultiLineStrippedEqual(self, "", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1016,11 +1013,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 10.00
|
||||
15.00 )
|
||||
( 10.00 15.00 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 0.00
|
||||
0.00 )""", testing_stream.output)
|
||||
( 0.00 0.00 )""", testing_stream.output)
|
||||
|
||||
def test_tap_jvp_float0(self):
|
||||
def f(x, yint):
|
||||
@ -1042,11 +1037,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 5.00
|
||||
2 )
|
||||
( 5.00 2 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 2.00
|
||||
False )""", testing_stream.output)
|
||||
( 2.00 False )""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_float0_result(self):
|
||||
# https://github.com/google/jax/issues/7340
|
||||
@ -1068,11 +1061,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(dtypes.float0, g[1].dtype)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80]
|
||||
[11 12 13] )
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00]
|
||||
[False False False] )""", testing_stream.output)
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
|
||||
def test_tap_higher_order_grad_float0_result(self):
|
||||
# https://github.com/google/jax/issues/7340
|
||||
@ -1100,8 +1091,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
res = f_jax(x)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80]
|
||||
[11 12 13] )""", testing_stream.output)
|
||||
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# 1st order
|
||||
@ -1109,11 +1099,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
res_vjp1 = f_jax_vjp1(*args_vjp1)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80]
|
||||
[11 12 13] )
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00]
|
||||
[False False False] )""", testing_stream.output)
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# 2nd order
|
||||
@ -1149,8 +1137,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: [('batch', {'batch_dims': (None, 0)})]
|
||||
( 3.00
|
||||
[4.00 5.00] )""", testing_stream.output)
|
||||
( 3.00 [4.00 5.00] )""", testing_stream.output)
|
||||
|
||||
def test_tap_vmap_vmap(self):
|
||||
# A 2D tensor with x[i, j] = i + j using 2 vmap
|
||||
@ -1249,8 +1236,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.
|
||||
9. )"""
|
||||
( 3. 9. )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1258,8 +1244,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
( [0. 1. 2.]
|
||||
[0. 1. 4.] )"""
|
||||
( [0. 1. 2.] [0. 1. 4.] )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1267,10 +1252,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
transforms: ['jvp'] what: x,x^2
|
||||
( ( 3.
|
||||
9. )
|
||||
( 0.1
|
||||
0.6 ) )"""
|
||||
( ( 3. 9. ) ( 0.1 0.6 ) )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1278,11 +1260,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3.
|
||||
9. )
|
||||
( 3. 9. )
|
||||
transforms: ['jvp', 'transpose'] what: x,x^2
|
||||
( 0.
|
||||
3. )"""
|
||||
( 0. 3. )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -1290,11 +1270,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
( [2. 3.]
|
||||
[4. 9.] )
|
||||
( [2. 3.] [4. 9.] )
|
||||
transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
|
||||
( 0.
|
||||
[2. 3.] )"""
|
||||
( 0. [2. 3.] )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
|
||||
@ -1320,11 +1298,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
assertMultiDeviceOutputEqual(
|
||||
self, """
|
||||
device: cpu:0 what: x,x^2
|
||||
( 3
|
||||
9 )
|
||||
( 3 9 )
|
||||
device: cpu:1 what: x,x^2
|
||||
( 4
|
||||
16 )""")
|
||||
( 4 16 )""")
|
||||
|
||||
def test_tap_pmap_vmap(self):
|
||||
# A matrix M[ij] = i * 10 + j
|
||||
@ -1440,13 +1416,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
assertMultiDeviceOutputEqual(self, """
|
||||
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
|
||||
( [[ 0.00 2.00 4.00]
|
||||
[20.00 22.00 24.00]]
|
||||
[[0.20 0.20 0.20]
|
||||
[20.00 22.00 24.00]] [[0.20 0.20 0.20]
|
||||
[0.20 0.20 0.20]] )
|
||||
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
|
||||
( [[200.00 202.00 204.00]
|
||||
[220.00 222.00 224.00]]
|
||||
[[0.20 0.20 0.20]
|
||||
[220.00 222.00 224.00]] [[0.20 0.20 0.20]
|
||||
[0.20 0.20 0.20]] )""")
|
||||
|
||||
def test_tap_vmap_pmap(self):
|
||||
@ -1693,10 +1667,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5})] what: x
|
||||
( ( [0. 1. 2. 3. 4.]
|
||||
[0. 2. 4. 6. 8.] )
|
||||
( ( 3 )
|
||||
( 3 ) ) )""", testing_stream.output)
|
||||
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
|
||||
testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# With VMAP
|
||||
@ -1713,8 +1685,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
[5. 6. 7. 8. 9.]]
|
||||
[[ 0. 2. 4. 6. 8.]
|
||||
[10. 12. 14. 16. 18.]] )
|
||||
( ( [3. 4.] )
|
||||
( [3. 4.] ) ) )""", testing_stream.output)
|
||||
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# With JVP
|
||||
@ -1724,14 +1695,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
|
||||
( ( ( [0. 1. 2. 3. 4.]
|
||||
[0. 2. 4. 6. 8.] )
|
||||
( ( 3 )
|
||||
( 3 ) ) )
|
||||
( ( [0. 0.1 0.2 0.3 0.4]
|
||||
[0. 0.2 0.4 0.6 0.8] )
|
||||
( ( False )
|
||||
( False ) ) ) )""", testing_stream.output)
|
||||
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
|
||||
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
|
||||
testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# Now with JIT
|
||||
@ -1739,10 +1705,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5})] what: x
|
||||
( ( [0. 1. 2. 3. 4.]
|
||||
[0. 2. 4. 6. 8.] )
|
||||
( ( 3 )
|
||||
( 3 ) ) )""", testing_stream.output)
|
||||
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
|
||||
testing_stream.output)
|
||||
|
||||
def test_tap_callback_delay(self):
|
||||
hcb.callback_extra = lambda dev: time.sleep(1)
|
||||
|
@ -80,7 +80,7 @@ class MetadataTest(jtu.JaxTestCase):
|
||||
return jax.lax.cond(True, x, true_fun, x, false_fun)
|
||||
hlo = jax.xla_computation(f)(1.).get_hlo_module().to_string()
|
||||
self.assertRegex(hlo, 'op_type="cond"')
|
||||
self.assertRegex(hlo, 'op_name=".*cond\\[ linear=\\(False, False\\) \\]"')
|
||||
self.assertRegex(hlo, 'op_name=".*cond\\[linear=\\(False, False\\)\\]"')
|
||||
self.assertRegex(hlo, 'op_type="cos"')
|
||||
self.assertRegex(hlo, 'op_name=".*cond/branch_0_fun/cos"')
|
||||
self.assertRegex(hlo, 'op_type="sin"')
|
||||
|
Loading…
x
Reference in New Issue
Block a user