Use the same name for aliased Vars when pretty-printing Jaxprs.

Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
such as `cond`.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first, at the top level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which is now named
`core.pp_shared_jaxpr`).
This commit is contained in:
George Necula 2025-02-24 10:03:23 +02:00
parent eee4d6019b
commit a6c47d6f36
4 changed files with 134 additions and 40 deletions

View File

@ -2955,8 +2955,11 @@ def _check_map(ctx_factory, prim, in_avals, params):
# ------------------- Jaxpr printed representation -------------------
def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
custom_pp_eqn_rules=True, name_stack=False,
def pp_toplevel_jaxpr(jaxpr_to_print: Jaxpr, *,
source_info: bool = False,
print_shapes: bool = True,
custom_pp_eqn_rules : bool = True,
name_stack: bool = False,
print_effects: bool = False) -> pp.Doc:
context = JaxprPpContext()
settings = JaxprPpSettings(
@ -2998,9 +3001,9 @@ def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
name_counts[name] += 1
else:
name_counts[name] += 1
docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
context.used_names.add(name)
context.top_level_jaxprs[jaxpr] = name
docs.append(pp_shared_jaxpr(name, jaxpr, context, settings))
context.shared_jaxpr_names.add(name)
context.shared_jaxprs[jaxpr] = name
docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
return pp.concat(docs)
@ -3025,19 +3028,41 @@ def _encode_digits_alphabetic(n: int) -> str:
# Jaxprs.
class JaxprPpContext:
var_names: defaultdict[Var, str]
used_names: MutableSet[str]
top_level_jaxprs: MutableMapping[Jaxpr, str]
# Shared jaxprs are those that are used multiple times and are printed
# first.
shared_jaxprs: MutableMapping[Jaxpr, str] # maps shared jaxpr to its name
shared_jaxpr_names: MutableSet[str]
def __init__(self) -> None:
self.top_level_jaxprs = {}
self.used_names = set()
self.shared_jaxprs = {}
self.shared_jaxpr_names = set()
fresh_names: Iterator[str] = (
name
for i in it.count()
if (name := _encode_digits_alphabetic(i)) not in self.used_names
if (name := _encode_digits_alphabetic(i)) not in self.shared_jaxpr_names
)
self.var_names = defaultdict(fresh_names.__next__)
def suggest_same_var_names(self,
                           for_vars: Sequence[Atom],
                           like_vars: Sequence[Atom]) -> None:
  """Reuses the printed names of `like_vars` for the aliased `for_vars`.

  The two sequences are paired positionally. For each pair, the name that
  `pp_var` yields for the `like_vars` entry is recorded in `self.var_names`
  for the `for_vars` entry, provided that:

    * both entries are `Var` instances,
    * the `like_vars` entry has not already lent its name to an earlier
      entry during this call, and
    * the `for_vars` entry does not yet have a name in `self.var_names`.

  Pairs that fail any of these conditions are left alone. If the sequences
  differ in length, nothing is recorded at all.
  """
  if len(for_vars) != len(like_vars):
    # A length mismatch can happen if a primitive containing a subjaxpr is
    # invoked with the wrong number of arguments, e.g., when printing an
    # invalid Jaxpr. Make no suggestions in that case.
    return

  # `like_vars` entries whose name has already been handed out in this call;
  # each one lends its name at most once.
  claimed: set[Var] = set()
  for target, source in zip(for_vars, like_vars):
    if not isinstance(source, Var) or source in claimed:
      continue
    if not isinstance(target, Var) or target in self.var_names:
      continue
    claimed.add(source)
    self.var_names[target] = pp_var(source, self)
def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
if isinstance(v, (Literal, DropVar)): return str(v)
@ -3051,7 +3076,7 @@ def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
else:
return a.str_short(short_dtypes=True)
def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
def pp_vars(vs: Sequence[Atom], context: JaxprPpContext,
*, separator="", print_shapes: bool = False) -> pp.Doc:
if print_shapes:
return pp.nest(2, pp.group(
@ -3097,7 +3122,8 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
user_frame = source_info_util.user_frame(eqn.source_info)
return doc if user_frame is None else pp.source_map(doc, user_frame)
def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings,
params: Sequence[str] | None = None) -> pp.Doc:
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
if params is None:
@ -3112,9 +3138,10 @@ def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
else:
return pp.concat(rhs)
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
def pp_eqns(eqns: Sequence[JaxprEqn],
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
return pp.join(
pp.brk("; "),
[pp_eqn(e, context, settings) for e in eqns])
@ -3139,7 +3166,7 @@ custom_str_eqn_compact_rules: dict[
Primitive, Callable[[Primitive, dict[Any, Any]], str]
] = {}
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
def pp_jaxpr_skeleton(jaxpr: Jaxpr, eqns_fn, context: JaxprPpContext,
settings: JaxprPpSettings) -> pp.Doc:
constvars = pp_vars(jaxpr.constvars, context, print_shapes=settings.print_shapes)
invars = pp_vars(jaxpr.invars, context, print_shapes=settings.print_shapes)
@ -3173,7 +3200,7 @@ def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
])) + pp.text(" }"))
def pp_top_level_jaxpr(
def pp_shared_jaxpr(
name: str,
jaxpr: Jaxpr,
context: JaxprPpContext,
@ -3192,13 +3219,14 @@ def pp_jaxpr(
context: JaxprPpContext,
settings: JaxprPpSettings,
) -> pp.Doc:
if name := context.top_level_jaxprs.get(jaxpr):
if name := context.shared_jaxprs.get(jaxpr):
return pp.text(name)
eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, settings)
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
def pp_jaxprs(jaxprs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr],
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp.group(pp.nest(2, pp.concat([
pp.text('('), pp.brk(""),

View File

@ -4360,8 +4360,6 @@ def _convert_elt_type_fwd_rule(eqn):
return [None], eqn
def _convert_elt_type_pp_rule(eqn, context, settings):
# don't print new_dtype because the output binder shows it, don't print
# weak_type when false
params = dict(eqn.params)
if params['sharding'] is None:
del params['sharding'] # don't show trivial case

View File

@ -2447,7 +2447,9 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
def _pjit_pp_rule(eqn, context, settings):
def _pjit_pp_rule(eqn: core.JaxprEqn,
context: core.JaxprPpContext,
settings: core.JaxprPpSettings) -> core.pp.Doc:
params = dict(eqn.params)
del params['inline']
if not any(params['donated_invars']):
@ -2468,6 +2470,10 @@ def _pjit_pp_rule(eqn, context, settings):
if not params['compiler_options_kvs']:
del params['compiler_options_kvs']
if params['jaxpr'].jaxpr not in context.shared_jaxprs:
context.suggest_same_var_names(params['jaxpr'].jaxpr.invars, eqn.invars)
context.suggest_same_var_names(params['jaxpr'].jaxpr.outvars, eqn.outvars)
# Move name= to the front to make the resulting equation easier to scan.
del params["name"]
return core._pp_eqn(eqn, context, settings, params=["name"] + sorted(params))

View File

@ -1222,16 +1222,78 @@ class PJitTest(jtu.BufferDonationTestCase):
{ lambda ; c:f32[1]. let
d:f32[1] = pjit[
name=<lambda>
jaxpr={ lambda ; e:f32[1]. let
f:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
g:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
h:f32[1] = add f g
in (h,) }
jaxpr={ lambda ; c:f32[1]. let
e:f32[1] = pjit[name=<lambda> jaxpr=lambda] c
f:f32[1] = pjit[name=<lambda> jaxpr=lambda] c
d:f32[1] = add e f
in (d,) }
] c
in (d,) }
""").strip(),
)
def test_pretty_print_pjit_id(self):
# `f` returns its first argument and is called with the same value for both
# parameters. In the expected text the inner jaxpr's first invar is printed
# with the caller's name `a`, the second invar gets the distinct name `b`,
# and the pjit equation binds no outputs.
f = pjit(lambda x, y: x)
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(lambda y: y + f(y, y))(x)
self.assertEqual(
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
{ lambda ; a:f32[1]. let
pjit[name=<lambda> jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a
c:f32[1] = add a a
in (c,) }
""").strip(),
)
def test_pretty_print_with_constant_pjit_arg(self):
# The second pjit operand is the literal `1.0`. In the expected text the
# inner invar paired with that literal gets its own name `c`, while the
# first invar and the output reuse the caller's names `a` and `b`.
f = pjit(lambda x, y: x * y)
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(lambda x: f(x, np.float32(1.0)))(x)
self.assertEqual(
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
{ lambda ; a:f32[1]. let
b:f32[1] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) }
] a 1.0
in (b,) }
""").strip(),
)
def test_pretty_print_with_aliased_args(self):
# The same caller value is passed for all three parameters. In the expected
# text only the first inner invar is printed as `a`; the other two get the
# distinct names `c` and `d`. The inner result reuses the caller's output
# name `b`.
f = pjit(lambda x, y, z: x * y * z)
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(lambda x: f(x, x, x))(x)
self.assertEqual(
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
{ lambda ; a:f32[1]. let
b:f32[1] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[1] c:f32[1] d:f32[1]. let
e:f32[1] = mul a c
b:f32[1] = mul e d
in (b,) }
] a a a
in (b,) }
""").strip(),
)
def test_pretty_print_with_literal_outvar(self):
# `f` returns a constant together with its input. In the expected text the
# inner jaxpr's only output is the literal `2`, which the pjit equation
# binds to `b`; the outer jaxpr returns `b` and the input `a`.
f = pjit(lambda x: (np.int32(2), x))
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(f)(x)
self.assertEqual(
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
{ lambda ; a:f32[1]. let
b:i32[] = pjit[name=<lambda> jaxpr={ lambda ; a:f32[1]. let in (2,) }] a
in (b, a) }
""").strip(),
)
def test_pretty_print_with_closure(self):
@pjit
def g(x, y):
@ -1249,11 +1311,11 @@ class PJitTest(jtu.BufferDonationTestCase):
{ lambda ; d:f32[1] e:f32[1]. let
g:f32[1] = pjit[
name=g
jaxpr={ lambda ; h:f32[1] i:f32[1]. let
j:f32[1] = pjit[name=f jaxpr=f] i h
k:f32[1] = pjit[name=f jaxpr=f] i i
l:f32[1] = add j k
in (l,) }
jaxpr={ lambda ; d:f32[1] e:f32[1]. let
h:f32[1] = pjit[name=f jaxpr=f] e d
i:f32[1] = pjit[name=f jaxpr=f] e e
g:f32[1] = add h i
in (g,) }
] d e
in (g,) }
""").strip(),
@ -1279,15 +1341,15 @@ class PJitTest(jtu.BufferDonationTestCase):
{ lambda ; c:f32[1] d:f32[2]. let
e:f32[2] = pjit[
name=g
jaxpr={ lambda ; g:f32[1] h:f32[2]. let
pjit[name=f jaxpr=f] g
pjit[name=f jaxpr=f] g
i:f32[1] = mul g g
pjit[name=f jaxpr=f1] h
pjit[name=f jaxpr=f1] h
j:f32[2] = mul h h
k:f32[2] = add i j
in (k,) }
jaxpr={ lambda ; c:f32[1] d:f32[2]. let
pjit[name=f jaxpr=f] c
pjit[name=f jaxpr=f] c
g:f32[1] = mul c c
pjit[name=f jaxpr=f1] d
pjit[name=f jaxpr=f1] d
h:f32[2] = mul d d
e:f32[2] = add g h
in (e,) }
] c d
in (e,) }
""").strip(),