custom pp_eqn rules, simpler xla_call print

This commit is contained in:
Matthew Johnson 2021-11-23 15:51:49 -08:00
parent 28b3c46b9b
commit 8430deda3e
3 changed files with 23 additions and 10 deletions

View File

@ -415,7 +415,6 @@ which the computation should run. For example
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0
c:f32[1] = xla_call[
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
@ -423,9 +422,6 @@ which the computation should run. For example
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
j:f32[1] = add i h
in (j,) }
device=None
donated_invars=(False, False)
inline=False
name=inner
] a b
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a

View File

@ -2122,12 +2122,16 @@ def pp_eqn(eqn, context: JaxprPpContext, *, print_shapes=True, source_info=False
lhs = pp_vars(eqn.outvars, context, print_shapes=print_shapes)
annotation = (source_info_util.summarize(eqn.source_info)
if source_info else None)
return pp.concat([
lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
pp_kv_pairs(sorted(eqn.params.items()), context),
pp.text(" ") + pp_vars(eqn.invars, context)
])
rule = pp_eqn_rules.get(eqn.primitive)
if rule:
rhs = rule(eqn, context)
else:
rhs = [pp.text(eqn.primitive.name),
pp_kv_pairs(sorted(eqn.params.items()), context),
pp.text(" ") + pp_vars(eqn.invars, context)]
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext], Sequence[pp.Doc]]
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns, context: JaxprPpContext, *, print_shapes=True, source_info=False
) -> pp.Doc:

View File

@ -717,6 +717,19 @@ pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
) -> List[pp.Doc]:
printed_params = {k:v for k, v in eqn.params.items() if
k == 'call_jaxpr' or k == 'name' or
k == 'backend' and v is not None or
k == 'device' and v is not None or
k == 'donated_invars' and any(v)}
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
### translation tables
MYPY = False