mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12485 from LenaMartens:checkify-lower
PiperOrigin-RevId: 476922387
This commit is contained in:
commit
7962b01f5d
@ -489,17 +489,47 @@ CheckEffect = object()
|
||||
def assert_abstract_eval(err, code, payload, *, msgs):
|
||||
return [], {CheckEffect}
|
||||
|
||||
def assert_lowering_rule(*a, **k):
|
||||
# TODO(lenamartens): actually throw an error through emit_python_callable
|
||||
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
|
||||
raise ValueError('Cannot abstractly evaluate a checkify.check which was not'
|
||||
' functionalized. This probably means you tried to stage'
|
||||
' (jit/scan/pmap/...) a `check` without functionalizing it'
|
||||
' through `checkify.checkify`.'
|
||||
)
|
||||
mlir.register_lowering(assert_p, assert_lowering_rule)
|
||||
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
|
||||
functionalization_error = ValueError(
|
||||
'Cannot abstractly evaluate a checkify.check which was not'
|
||||
' functionalized. This probably means you tried to stage'
|
||||
' (jit/scan/pmap/...) a `check` without functionalizing it'
|
||||
' through `checkify.checkify`.'
|
||||
)
|
||||
|
||||
def python_err(msgs, err, code, payload):
|
||||
error = Error(err, code, msgs, payload)
|
||||
check_error(error)
|
||||
return []
|
||||
|
||||
def assert_lowering_rule(ctx, err, code, payload, *, msgs):
|
||||
if not config.jax_experimental_unsafe_xla_runtime_errors:
|
||||
raise functionalization_error
|
||||
|
||||
out_op, token_out, keep_alive = mlir.emit_python_callback(
|
||||
ctx, callback=lambda *a: python_err(msgs, *a),
|
||||
token=ctx.tokens_in.get(CheckEffect)[0],
|
||||
operands=[err, code, payload],
|
||||
operand_avals=list(ctx.avals_in),
|
||||
result_avals=list(ctx.avals_out),
|
||||
has_side_effect=True)
|
||||
ctx.set_tokens_out(ctx.tokens_in.update_tokens(
|
||||
mlir.TokenSet({CheckEffect: token_out})))
|
||||
ctx.module_context.add_keepalive(keep_alive)
|
||||
return out_op
|
||||
|
||||
def assert_lowering_rule_unsupported(*a, **k):
|
||||
raise functionalization_error
|
||||
|
||||
mlir.register_lowering(assert_p, assert_lowering_rule_unsupported,
|
||||
platform='tpu')
|
||||
mlir.register_lowering(assert_p, assert_lowering_rule,
|
||||
platform='cpu')
|
||||
mlir.register_lowering(assert_p, assert_lowering_rule,
|
||||
platform='gpu')
|
||||
mlir.lowerable_effects.add(CheckEffect)
|
||||
cf.allowed_effects.add(CheckEffect)
|
||||
core.ordered_effects.add(CheckEffect)
|
||||
|
||||
|
||||
def assert_batching_rule(batched_args, batch_dims, *, msgs):
|
||||
|
@ -909,6 +909,16 @@ config.define_bool_state(
|
||||
upgrade=True,
|
||||
help='Enable eager-mode pmap when jax_disable_jit is activated.')
|
||||
|
||||
config.define_bool_state(
|
||||
name='jax_experimental_unsafe_xla_runtime_errors',
|
||||
default=False,
|
||||
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
|
||||
'on CPU and GPU. These errors are async, might get lost and are not '
|
||||
'very readable. But, they crash the computation and enable you '
|
||||
'to write jittable checks without needing to checkify. Does not '
|
||||
'work under pmap/pjit.')
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src.lib import xla_extension
|
||||
from jax.config import config
|
||||
from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
@ -905,5 +906,26 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "value needs to be less than 6")
|
||||
|
||||
class LowerableChecksTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.prev = config.jax_experimental_unsafe_xla_runtime_errors
|
||||
config.update("jax_experimental_unsafe_xla_runtime_errors", True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_experimental_unsafe_xla_runtime_errors", self.prev)
|
||||
super().tearDown()
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_jit(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
checkify.check(x > 0, "x needs to be positive")
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
|
||||
"x needs to be positive"):
|
||||
f(-1.)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user