mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
custom pp_eqn rules, simpler xla_call print
This commit is contained in:
parent
28b3c46b9b
commit
8430deda3e
@ -415,7 +415,6 @@ which the computation should run. For example
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = sub a 2.0
|
||||
c:f32[1] = xla_call[
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; d:f32[] e:f32[]. let
|
||||
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
|
||||
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
|
||||
@ -423,9 +422,6 @@ which the computation should run. For example
|
||||
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
|
||||
j:f32[1] = add i h
|
||||
in (j,) }
|
||||
device=None
|
||||
donated_invars=(False, False)
|
||||
inline=False
|
||||
name=inner
|
||||
] a b
|
||||
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
|
||||
|
16
jax/core.py
16
jax/core.py
@ -2122,12 +2122,16 @@ def pp_eqn(eqn, context: JaxprPpContext, *, print_shapes=True, source_info=False
|
||||
lhs = pp_vars(eqn.outvars, context, print_shapes=print_shapes)
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if source_info else None)
|
||||
return pp.concat([
|
||||
lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
|
||||
pp_kv_pairs(sorted(eqn.params.items()), context),
|
||||
pp.text(" ") + pp_vars(eqn.invars, context)
|
||||
])
|
||||
|
||||
rule = pp_eqn_rules.get(eqn.primitive)
|
||||
if rule:
|
||||
rhs = rule(eqn, context)
|
||||
else:
|
||||
rhs = [pp.text(eqn.primitive.name),
|
||||
pp_kv_pairs(sorted(eqn.params.items()), context),
|
||||
pp.text(" ") + pp_vars(eqn.invars, context)]
|
||||
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
|
||||
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext], Sequence[pp.Doc]]
|
||||
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
|
||||
|
||||
def pp_eqns(eqns, context: JaxprPpContext, *, print_shapes=True, source_info=False
|
||||
) -> pp.Doc:
|
||||
|
@ -717,6 +717,19 @@ pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
|
||||
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
|
||||
|
||||
|
||||
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
|
||||
) -> List[pp.Doc]:
|
||||
printed_params = {k:v for k, v in eqn.params.items() if
|
||||
k == 'call_jaxpr' or k == 'name' or
|
||||
k == 'backend' and v is not None or
|
||||
k == 'device' and v is not None or
|
||||
k == 'donated_invars' and any(v)}
|
||||
return [pp.text(eqn.primitive.name),
|
||||
core.pp_kv_pairs(sorted(printed_params.items()), context),
|
||||
pp.text(" ") + core.pp_vars(eqn.invars, context)]
|
||||
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
|
||||
|
||||
|
||||
### translation tables
|
||||
|
||||
MYPY = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user