From 55da62ff75c9c7354c9522a82791ee47c21def1a Mon Sep 17 00:00:00 2001 From: Lena Martens Date: Mon, 12 Jun 2023 10:58:00 -0700 Subject: [PATCH] Better pprint rule for check_p primitive. PiperOrigin-RevId: 539703344 --- jax/_src/checkify.py | 15 +++++++++++++++ tests/checkify_test.py | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 9c083f69a..518c887de 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -459,6 +459,21 @@ def _reduce_any_error(error: Error): check_p = core.Primitive('check') check_p.multiple_results = True # zero results + +def _pp_check(eqn, context, settings) -> core.pp.Doc: + annotation = (source_info_util.summarize(eqn.source_info) + if settings.source_info else None) + name_stack_annotation = (f'[{eqn.source_info.name_stack}]' + if settings.name_stack else None) + trimmed_params = sorted((k, v) for (k, v) in eqn.params.items() + if k != "err_tree") + rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation), + core.pp_kv_pairs(trimmed_params, context, settings), + core.pp.text(" ") + core.pp_vars(eqn.invars, context)] + return core.pp.concat([core.pp.text("", annotation), *rhs]) + +core.pp_eqn_rules[check_p] = _pp_check + # TODO(lenamartens): inherit from Exception instead of ValueError. class JaxRuntimeError(ValueError): pass diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 14c28c545..7cee2d73b 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -1295,6 +1295,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase): _ = jax.jit(f, static_argnums=(0,))(True) + def test_check_pp_rule(self): + jaxpr = jax.make_jaxpr(lambda: checkify.check(False, "hi"))() + jaxpr.pretty_print(source_info=True, name_stack=True) # Does not crash. + class LowerableChecksTest(jtu.JaxTestCase): def setUp(self):