Better pprint rule for check_p primitive.

PiperOrigin-RevId: 539703344
This commit is contained in:
Lena Martens 2023-06-12 10:58:00 -07:00 committed by jax authors
parent ed073aa6c9
commit 55da62ff75
2 changed files with 19 additions and 0 deletions

View File

@ -459,6 +459,21 @@ def _reduce_any_error(error: Error):
check_p = core.Primitive('check')
check_p.multiple_results = True # zero results
def _pp_check(eqn, context, settings) -> core.pp.Doc:
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
if settings.name_stack else None)
trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
if k != "err_tree")
rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
core.pp_kv_pairs(trimmed_params, context, settings),
core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
return core.pp.concat([core.pp.text("", annotation), *rhs])
core.pp_eqn_rules[check_p] = _pp_check
# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
pass

View File

@ -1295,6 +1295,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
_ = jax.jit(f, static_argnums=(0,))(True)
def test_check_pp_rule(self):
jaxpr = jax.make_jaxpr(lambda: checkify.check(False, "hi"))()
jaxpr.pretty_print(source_info=True, name_stack=True) # Does not crash.
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):