mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Better pprint rule for check_p primitive.
PiperOrigin-RevId: 539703344
This commit is contained in:
parent
ed073aa6c9
commit
55da62ff75
@ -459,6 +459,21 @@ def _reduce_any_error(error: Error):
|
||||
check_p = core.Primitive('check')
|
||||
check_p.multiple_results = True # zero results
|
||||
|
||||
|
||||
def _pp_check(eqn, context, settings) -> core.pp.Doc:
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if settings.source_info else None)
|
||||
name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
|
||||
if settings.name_stack else None)
|
||||
trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
|
||||
if k != "err_tree")
|
||||
rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
|
||||
core.pp_kv_pairs(trimmed_params, context, settings),
|
||||
core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
|
||||
return core.pp.concat([core.pp.text("", annotation), *rhs])
|
||||
|
||||
core.pp_eqn_rules[check_p] = _pp_check
|
||||
|
||||
# TODO(lenamartens): inherit from Exception instead of ValueError.
|
||||
class JaxRuntimeError(ValueError):
|
||||
pass
|
||||
|
@ -1295,6 +1295,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
||||
_ = jax.jit(f, static_argnums=(0,))(True)
|
||||
|
||||
def test_check_pp_rule(self):
|
||||
jaxpr = jax.make_jaxpr(lambda: checkify.check(False, "hi"))()
|
||||
jaxpr.pretty_print(source_info=True, name_stack=True) # Does not crash.
|
||||
|
||||
|
||||
class LowerableChecksTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user