Merge pull request #12485 from LenaMartens:checkify-lower

PiperOrigin-RevId: 476922387
This commit is contained in:
jax authors 2022-09-26 09:53:40 -07:00
commit 7962b01f5d
3 changed files with 71 additions and 9 deletions

View File

@ -489,17 +489,47 @@ CheckEffect = object()
def assert_abstract_eval(err, code, payload, *, msgs):
return [], {CheckEffect}
def assert_lowering_rule(*a, **k):
# TODO(lenamartens): actually throw an error through emit_python_callable
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
raise ValueError('Cannot abstractly evaluate a checkify.check which was not'
' functionalized. This probably means you tried to stage'
' (jit/scan/pmap/...) a `check` without functionalizing it'
' through `checkify.checkify`.'
)
mlir.register_lowering(assert_p, assert_lowering_rule)
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
functionalization_error = ValueError(
'Cannot abstractly evaluate a checkify.check which was not'
' functionalized. This probably means you tried to stage'
' (jit/scan/pmap/...) a `check` without functionalizing it'
' through `checkify.checkify`.'
)
def python_err(msgs, err, code, payload):
error = Error(err, code, msgs, payload)
check_error(error)
return []
def assert_lowering_rule(ctx, err, code, payload, *, msgs):
if not config.jax_experimental_unsafe_xla_runtime_errors:
raise functionalization_error
out_op, token_out, keep_alive = mlir.emit_python_callback(
ctx, callback=lambda *a: python_err(msgs, *a),
token=ctx.tokens_in.get(CheckEffect)[0],
operands=[err, code, payload],
operand_avals=list(ctx.avals_in),
result_avals=list(ctx.avals_out),
has_side_effect=True)
ctx.set_tokens_out(ctx.tokens_in.update_tokens(
mlir.TokenSet({CheckEffect: token_out})))
ctx.module_context.add_keepalive(keep_alive)
return out_op
def assert_lowering_rule_unsupported(*a, **k):
raise functionalization_error
mlir.register_lowering(assert_p, assert_lowering_rule_unsupported,
platform='tpu')
mlir.register_lowering(assert_p, assert_lowering_rule,
platform='cpu')
mlir.register_lowering(assert_p, assert_lowering_rule,
platform='gpu')
mlir.lowerable_effects.add(CheckEffect)
cf.allowed_effects.add(CheckEffect)
core.ordered_effects.add(CheckEffect)
def assert_batching_rule(batched_args, batch_dims, *, msgs):

View File

@ -909,6 +909,16 @@ config.define_bool_state(
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
config.define_bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
'on CPU and GPU. These errors are async, might get lost and are not '
'very readable. But, they crash the computation and enable you '
'to write jittable checks without needing to checkify. Does not '
'work under pmap/pjit.')
)
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""

View File

@ -22,6 +22,7 @@ import numpy as np
import jax
from jax import lax
import jax._src.test_util as jtu
from jax._src.lib import xla_extension
from jax.config import config
from jax.experimental import checkify
from jax.experimental import pjit
@ -905,5 +906,26 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "value needs to be less than 6")
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.prev = config.jax_experimental_unsafe_xla_runtime_errors
config.update("jax_experimental_unsafe_xla_runtime_errors", True)
def tearDown(self):
config.update("jax_experimental_unsafe_xla_runtime_errors", self.prev)
super().tearDown()
@jtu.skip_on_devices("tpu")
def test_jit(self):
@jax.jit
def f(x):
checkify.check(x > 0, "x needs to be positive")
return x
with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
"x needs to be positive"):
f(-1.)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())