mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that are aliased. In this PR, we use this for `pjit`, such that the following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))` prints: ``` { lambda ; a:f32[]. let b:f32[] = pjit[ name=<lambda> jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) } ] a in (b,) } ``` instead of the previous: ``` { lambda ; a:f32[]. let b:f32[] = pjit[ name=<lambda> jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) } ] a in (b,) } ``` The same mechanism could be used for other higher-order primitives, e.g., cond, and others. Also add some typing declarations and rename APIs to use "shared jaxpr" in lieu of "top-level jaxpr" for those Jaxprs that are used multiple times and are printed first. I presume that the term "top-level jaxpr" was picked because these are printed first at top-level. But this is confusing, because they are really subjaxprs. In fact, there was already a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr, and there was also `core.pp_top_level_jaxpr` (which now is named `core.pp_shared_jaxpr`.
This commit is contained in:
parent
eee4d6019b
commit
a6c47d6f36
@ -2955,8 +2955,11 @@ def _check_map(ctx_factory, prim, in_avals, params):
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
||||
def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
|
||||
custom_pp_eqn_rules=True, name_stack=False,
|
||||
def pp_toplevel_jaxpr(jaxpr_to_print: Jaxpr, *,
|
||||
source_info: bool = False,
|
||||
print_shapes: bool = True,
|
||||
custom_pp_eqn_rules : bool = True,
|
||||
name_stack: bool = False,
|
||||
print_effects: bool = False) -> pp.Doc:
|
||||
context = JaxprPpContext()
|
||||
settings = JaxprPpSettings(
|
||||
@ -2998,9 +3001,9 @@ def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
|
||||
name_counts[name] += 1
|
||||
else:
|
||||
name_counts[name] += 1
|
||||
docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
|
||||
context.used_names.add(name)
|
||||
context.top_level_jaxprs[jaxpr] = name
|
||||
docs.append(pp_shared_jaxpr(name, jaxpr, context, settings))
|
||||
context.shared_jaxpr_names.add(name)
|
||||
context.shared_jaxprs[jaxpr] = name
|
||||
docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
|
||||
return pp.concat(docs)
|
||||
|
||||
@ -3025,19 +3028,41 @@ def _encode_digits_alphabetic(n: int) -> str:
|
||||
# Jaxprs.
|
||||
class JaxprPpContext:
|
||||
var_names: defaultdict[Var, str]
|
||||
used_names: MutableSet[str]
|
||||
top_level_jaxprs: MutableMapping[Jaxpr, str]
|
||||
# Shared jaxprs are those that are used multiple times and are printed
|
||||
# first.
|
||||
shared_jaxprs: MutableMapping[Jaxpr, str] # maps shared jaxpr to its name
|
||||
shared_jaxpr_names: MutableSet[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.top_level_jaxprs = {}
|
||||
self.used_names = set()
|
||||
self.shared_jaxprs = {}
|
||||
self.shared_jaxpr_names = set()
|
||||
fresh_names: Iterator[str] = (
|
||||
name
|
||||
for i in it.count()
|
||||
if (name := _encode_digits_alphabetic(i)) not in self.used_names
|
||||
if (name := _encode_digits_alphabetic(i)) not in self.shared_jaxpr_names
|
||||
)
|
||||
self.var_names = defaultdict(fresh_names.__next__)
|
||||
|
||||
def suggest_same_var_names(self,
|
||||
for_vars: Sequence[Atom],
|
||||
like_vars: Sequence[Atom]) -> None:
|
||||
"""Suggests the names for `for_vars` to match those of `like_vars`.
|
||||
|
||||
`for_vars` are distinct Vars, and are aliased with `like_vars`.
|
||||
"""
|
||||
used_like_vars: set[Var] = set()
|
||||
if len(for_vars) != len(like_vars):
|
||||
# The mismatch can happen if a primitive containing a subjaxpr is invoked
|
||||
# with the wrong number of arguments, e.g., when printing an invalid Jaxpr.
|
||||
return
|
||||
for for_v, like_v in zip(for_vars, like_vars):
|
||||
if (isinstance(like_v, Var) and
|
||||
like_v not in used_like_vars and
|
||||
isinstance(for_v, Var) and
|
||||
for_v not in self.var_names):
|
||||
used_like_vars.add(like_v)
|
||||
self.var_names[for_v] = pp_var(like_v, self)
|
||||
|
||||
|
||||
def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
|
||||
if isinstance(v, (Literal, DropVar)): return str(v)
|
||||
@ -3051,7 +3076,7 @@ def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
|
||||
else:
|
||||
return a.str_short(short_dtypes=True)
|
||||
|
||||
def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
|
||||
def pp_vars(vs: Sequence[Atom], context: JaxprPpContext,
|
||||
*, separator="", print_shapes: bool = False) -> pp.Doc:
|
||||
if print_shapes:
|
||||
return pp.nest(2, pp.group(
|
||||
@ -3097,7 +3122,8 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
|
||||
user_frame = source_info_util.user_frame(eqn.source_info)
|
||||
return doc if user_frame is None else pp.source_map(doc, user_frame)
|
||||
|
||||
def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
|
||||
def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings,
|
||||
params: Sequence[str] | None = None) -> pp.Doc:
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if settings.source_info else None)
|
||||
if params is None:
|
||||
@ -3112,9 +3138,10 @@ def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
|
||||
else:
|
||||
return pp.concat(rhs)
|
||||
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
|
||||
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
|
||||
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
|
||||
|
||||
def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
|
||||
def pp_eqns(eqns: Sequence[JaxprEqn],
|
||||
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
|
||||
return pp.join(
|
||||
pp.brk("; "),
|
||||
[pp_eqn(e, context, settings) for e in eqns])
|
||||
@ -3139,7 +3166,7 @@ custom_str_eqn_compact_rules: dict[
|
||||
Primitive, Callable[[Primitive, dict[Any, Any]], str]
|
||||
] = {}
|
||||
|
||||
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
|
||||
def pp_jaxpr_skeleton(jaxpr: Jaxpr, eqns_fn, context: JaxprPpContext,
|
||||
settings: JaxprPpSettings) -> pp.Doc:
|
||||
constvars = pp_vars(jaxpr.constvars, context, print_shapes=settings.print_shapes)
|
||||
invars = pp_vars(jaxpr.invars, context, print_shapes=settings.print_shapes)
|
||||
@ -3173,7 +3200,7 @@ def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
|
||||
])) + pp.text(" }"))
|
||||
|
||||
|
||||
def pp_top_level_jaxpr(
|
||||
def pp_shared_jaxpr(
|
||||
name: str,
|
||||
jaxpr: Jaxpr,
|
||||
context: JaxprPpContext,
|
||||
@ -3192,13 +3219,14 @@ def pp_jaxpr(
|
||||
context: JaxprPpContext,
|
||||
settings: JaxprPpSettings,
|
||||
) -> pp.Doc:
|
||||
if name := context.top_level_jaxprs.get(jaxpr):
|
||||
if name := context.shared_jaxprs.get(jaxpr):
|
||||
return pp.text(name)
|
||||
eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, settings)
|
||||
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
|
||||
|
||||
|
||||
def pp_jaxprs(jaxprs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
|
||||
def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr],
|
||||
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
|
||||
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
|
||||
return pp.group(pp.nest(2, pp.concat([
|
||||
pp.text('('), pp.brk(""),
|
||||
|
@ -4360,8 +4360,6 @@ def _convert_elt_type_fwd_rule(eqn):
|
||||
return [None], eqn
|
||||
|
||||
def _convert_elt_type_pp_rule(eqn, context, settings):
|
||||
# don't print new_dtype because the output binder shows it, don't print
|
||||
# weak_type when false
|
||||
params = dict(eqn.params)
|
||||
if params['sharding'] is None:
|
||||
del params['sharding'] # don't show trivial case
|
||||
|
@ -2447,7 +2447,9 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
|
||||
|
||||
|
||||
def _pjit_pp_rule(eqn, context, settings):
|
||||
def _pjit_pp_rule(eqn: core.JaxprEqn,
|
||||
context: core.JaxprPpContext,
|
||||
settings: core.JaxprPpSettings) -> core.pp.Doc:
|
||||
params = dict(eqn.params)
|
||||
del params['inline']
|
||||
if not any(params['donated_invars']):
|
||||
@ -2468,6 +2470,10 @@ def _pjit_pp_rule(eqn, context, settings):
|
||||
if not params['compiler_options_kvs']:
|
||||
del params['compiler_options_kvs']
|
||||
|
||||
if params['jaxpr'].jaxpr not in context.shared_jaxprs:
|
||||
context.suggest_same_var_names(params['jaxpr'].jaxpr.invars, eqn.invars)
|
||||
context.suggest_same_var_names(params['jaxpr'].jaxpr.outvars, eqn.outvars)
|
||||
|
||||
# Move name= to the front to make the resulting equation easier to scan.
|
||||
del params["name"]
|
||||
return core._pp_eqn(eqn, context, settings, params=["name"] + sorted(params))
|
||||
|
@ -1222,16 +1222,78 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
{ lambda ; c:f32[1]. let
|
||||
d:f32[1] = pjit[
|
||||
name=<lambda>
|
||||
jaxpr={ lambda ; e:f32[1]. let
|
||||
f:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
|
||||
g:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
|
||||
h:f32[1] = add f g
|
||||
in (h,) }
|
||||
jaxpr={ lambda ; c:f32[1]. let
|
||||
e:f32[1] = pjit[name=<lambda> jaxpr=lambda] c
|
||||
f:f32[1] = pjit[name=<lambda> jaxpr=lambda] c
|
||||
d:f32[1] = add e f
|
||||
in (d,) }
|
||||
] c
|
||||
in (d,) }
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_pretty_print_pjit_id(self):
|
||||
f = pjit(lambda x, y: x)
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(lambda y: y + f(y, y))(x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
{ lambda ; a:f32[1]. let
|
||||
pjit[name=<lambda> jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a
|
||||
c:f32[1] = add a a
|
||||
in (c,) }
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_pretty_print_with_constant_pjit_arg(self):
|
||||
f = pjit(lambda x, y: x * y)
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(lambda x: f(x, np.float32(1.0)))(x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
{ lambda ; a:f32[1]. let
|
||||
b:f32[1] = pjit[
|
||||
name=<lambda>
|
||||
jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) }
|
||||
] a 1.0
|
||||
in (b,) }
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_pretty_print_with_aliased_args(self):
|
||||
f = pjit(lambda x, y, z: x * y * z)
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(lambda x: f(x, x, x))(x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
{ lambda ; a:f32[1]. let
|
||||
b:f32[1] = pjit[
|
||||
name=<lambda>
|
||||
jaxpr={ lambda ; a:f32[1] c:f32[1] d:f32[1]. let
|
||||
e:f32[1] = mul a c
|
||||
b:f32[1] = mul e d
|
||||
in (b,) }
|
||||
] a a a
|
||||
in (b,) }
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_pretty_print_with_literal_outvar(self):
|
||||
f = pjit(lambda x: (np.int32(2), x))
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(f)(x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
{ lambda ; a:f32[1]. let
|
||||
b:i32[] = pjit[name=<lambda> jaxpr={ lambda ; a:f32[1]. let in (2,) }] a
|
||||
in (b, a) }
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_pretty_print_with_closure(self):
|
||||
@pjit
|
||||
def g(x, y):
|
||||
@ -1249,11 +1311,11 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
{ lambda ; d:f32[1] e:f32[1]. let
|
||||
g:f32[1] = pjit[
|
||||
name=g
|
||||
jaxpr={ lambda ; h:f32[1] i:f32[1]. let
|
||||
j:f32[1] = pjit[name=f jaxpr=f] i h
|
||||
k:f32[1] = pjit[name=f jaxpr=f] i i
|
||||
l:f32[1] = add j k
|
||||
in (l,) }
|
||||
jaxpr={ lambda ; d:f32[1] e:f32[1]. let
|
||||
h:f32[1] = pjit[name=f jaxpr=f] e d
|
||||
i:f32[1] = pjit[name=f jaxpr=f] e e
|
||||
g:f32[1] = add h i
|
||||
in (g,) }
|
||||
] d e
|
||||
in (g,) }
|
||||
""").strip(),
|
||||
@ -1279,15 +1341,15 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
{ lambda ; c:f32[1] d:f32[2]. let
|
||||
e:f32[2] = pjit[
|
||||
name=g
|
||||
jaxpr={ lambda ; g:f32[1] h:f32[2]. let
|
||||
pjit[name=f jaxpr=f] g
|
||||
pjit[name=f jaxpr=f] g
|
||||
i:f32[1] = mul g g
|
||||
pjit[name=f jaxpr=f1] h
|
||||
pjit[name=f jaxpr=f1] h
|
||||
j:f32[2] = mul h h
|
||||
k:f32[2] = add i j
|
||||
in (k,) }
|
||||
jaxpr={ lambda ; c:f32[1] d:f32[2]. let
|
||||
pjit[name=f jaxpr=f] c
|
||||
pjit[name=f jaxpr=f] c
|
||||
g:f32[1] = mul c c
|
||||
pjit[name=f jaxpr=f1] d
|
||||
pjit[name=f jaxpr=f1] d
|
||||
h:f32[2] = mul d d
|
||||
e:f32[2] = add g h
|
||||
in (e,) }
|
||||
] c d
|
||||
in (e,) }
|
||||
""").strip(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user