errors: create TracerBoolConversionError for more targeted debugging tips

This commit is contained in:
Jake VanderPlas 2023-06-21 01:41:45 -07:00
parent 06f76bc6bc
commit f1e603e4b3
5 changed files with 123 additions and 59 deletions

View File

@ -9,5 +9,6 @@ along with representative examples of how one might fix them.
.. autoclass:: ConcretizationTypeError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerBoolConversionError
.. autoclass:: TracerIntegerConversionError
.. autoclass:: UnexpectedTracerError

View File

@ -41,7 +41,7 @@ from jax._src import config as jax_config
from jax._src import effects
from jax._src.config import FLAGS, config
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError,
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
@ -1366,8 +1366,12 @@ def concretization_function_error(fun, suggest_astype=False):
fname_context += ("If trying to convert the data type of a value, "
f"try using `x.astype({fun.__name__})` "
f"or `jnp.array(x, {fun.__name__})` instead.")
def error(self, arg):
raise ConcretizationTypeError(arg, fname_context)
if fun is bool:
def error(self, arg):
raise TracerBoolConversionError(arg)
else:
def error(self, arg):
raise ConcretizationTypeError(arg, fname_context)
return error
def concrete_or_error(force: Any, val: Any, context=""):

View File

@ -53,33 +53,7 @@ class ConcretizationTypeError(JAXTypeError):
program is doing operations that are not directly supported by JAX's JIT
compilation model.
Using non-JAX aware functions
One common cause of this error is using non-JAX aware functions within JAX
code. For example:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
... return min(x, 0)
>>> func(2) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected:
The problem arose with the `bool` function.
In this case, the error occurs because Python's built-in ``min`` function is not
compatible with JAX transforms. This can be fixed by replacing it with
``jnp.minumum``:
>>> @jit
... def func(x):
... return jnp.minimum(x, 0)
>>> print(func(2))
0
Examples:
Traced value where static value is expected
One common cause of this error is using a traced value where a static value
@ -107,34 +81,6 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
Traced value used in control flow
Another case where this often arises is when a traced value is used in
Python control flow. For example::
>>> @jit
... def func(x, y):
... return x if x.sum() < y.sum() else y
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: [...]
We could mark both inputs ``x`` and ``y`` as static, but that would defeat
the purpose of using :func:`jax.jit` here. Another option is to re-express
the if statement in terms of :func:`jax.numpy.where`::
>>> @jit
... def func(x, y):
... return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see
:ref:`lax-control-flow`.
Shape depends on Traced Value
Such an error may also arise when a shape in your JIT-compiled computation
depends on the values within a traced quantity. For example::
@ -464,6 +410,118 @@ class TracerIntegerConversionError(JAXTypeError):
f"{tracer._origin_msg()}")
@export
class TracerBoolConversionError(ConcretizationTypeError):
"""
This error occurs when a traced value in JAX is used in a context where a
boolean value is expected (see :ref:`faq-different-kinds-of-jax-values`
for more on what a Tracer is).
The boolean cast may be an explicit (e.g. ``bool(x)``) or implicit, through use of
control flow (e.g. ``if x > 0`` or ``while x``), use of Python boolean
operators (e.g. ``z = x and y``, ``z = x or y``, ``z = not x``) or functions
that use them (e.g. ``z = max(x, y)``, ``z = min(x, y)`` etc.).
In some situations, this problem can be easily fixed by marking traced values as
static; in others, it may indicate that your program is doing operations that are
not directly supported by JAX's JIT compilation model.
Examples:
Traced value used in control flow
One case where this often arises is when a traced value is used in
Python control flow. For example::
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
... return x if x.sum() < y.sum() else y
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
We could mark both inputs ``x`` and ``y`` as static, but that would defeat
the purpose of using :func:`jax.jit` here. Another option is to re-express
the if statement in terms of the three-term :func:`jax.numpy.where`::
>>> @jit
... def func(x, y):
... return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see
:ref:`lax-control-flow`.
Control flow on traced values
Another common cause of this error is if you inadvertently trace over a boolean
flag. For example::
>>> @jit
... def func(x, normalize=True):
... if normalize:
... return x / x.sum()
... return x
>>> func(jnp.arange(5), True) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
Here because the flag ``normalize`` is traced, it cannot be used in Python
control flow. In this situation, the best solution is probably to mark this
value as static::
>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
... if normalize:
... return x / x.sum()
... return x
>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
For more on ``static_argnums``, see the documentation of :func:`jax.jit`.
Using non-JAX aware functions
Another common cause of this error is using non-JAX aware functions within JAX
code. For example:
>>> @jit
... def func(x):
... return min(x, 0)
>>> func(2) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
In this case, the error occurs because Python's built-in ``min`` function is not
compatible with JAX transforms. This can be fixed by replacing it with
``jnp.minumum``:
>>> @jit
... def func(x):
... return jnp.minimum(x, 0)
>>> print(func(2))
0
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: core.Tracer):
JAXTypeError.__init__(self,
f"Attempted boolean conversion of {tracer._error_repr()}."
f"{tracer._origin_msg()}")
@export
class UnexpectedTracerError(JAXTypeError):
"""

View File

@ -21,6 +21,7 @@ from jax._src.errors import (
ConcretizationTypeError as ConcretizationTypeError,
NonConcreteBooleanIndexError as NonConcreteBooleanIndexError,
TracerArrayConversionError as TracerArrayConversionError,
TracerBoolConversionError as TracerBoolConversionError,
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
)

View File

@ -1473,7 +1473,7 @@ class APITest(jtu.JaxTestCase):
assert grad(f)(1.0) == 1.0
assert grad(f)(-1.0) == -1.0
with self.assertRaisesRegex(core.ConcretizationTypeError,
"Abstract tracer value"):
"Attempted boolean conversion"):
jit(f)(1)
def test_list_index_err(self):