Add jax.errors submodule & error troubleshooting docs

This commit is contained in:
Jake VanderPlas 2021-03-02 09:29:59 -08:00
parent 59fada9446
commit 12c84e7a50
13 changed files with 342 additions and 87 deletions

View File

@ -125,4 +125,4 @@ jobs:
- name: Test documentation
run: |
pytest -n 1 docs
pytest -n 1 --doctest-modules jax/api.py
pytest -n 1 --doctest-modules jax/api.py jax/_src/errors.py

11
docs/errors.rst Normal file
View File

@ -0,0 +1,11 @@
.. _jax-errors:
JAX Errors
==========
This page lists a few of the errors you might encounter when using JAX,
along with representative examples of how one might fix them.
.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerIntegerConversionError

View File

@ -218,51 +218,12 @@ JAX/accelerators vs NumPy/CPU. For example, if switch this example to use
``Abstract tracer value encountered where concrete value is expected`` error
----------------------------------------------------------------------------
See :class:`jax.errors.ConcretizationTypeError`
If you are getting an error that a library function is called with
*"Abstract tracer value encountered where concrete value is expected"*, you may need to
change how you invoke JAX transformations. Below is an example and a couple of possible
solutions, followed by the details of what is actually happening, if you are curious
or the simple solution does not work for you.
Some library functions take arguments that specify shapes or axes,
such as the second and third arguments for :func:`jax.numpy.split`::
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
np.split(np.zeros(2), 2, 0) # works
If you try the following code::
jax.jit(np.split)(np.zeros(4), 2, 0)
you will get the following error::
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
You must change the way you use :func:`jax.jit` to ensure that the ``num_sections``
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively).
The best mechanism is to use special transformation parameters
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`::
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
An alternative is to apply the transformation to a closure
that encapsulates the arguments to be protected, either manually as below
or by using ``functools.partial``::
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
**Note a new closure is created at every invocation, which defeats the
compilation caching mechanism, which is why static_argnums is preferred.**
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_.
.. _faq-different-kinds-of-jax-values:
Different kinds of JAX values
------------------------------
-----------------------------
In the process of transforming functions, JAX replaces some function
arguments with special tracer values.

View File

@ -43,6 +43,7 @@ For an introduction to JAX, start at the
CHANGELOG
faq
errors
jaxpr
async_dispatch
concurrency

View File

@ -137,6 +137,7 @@ Operators
top_k
transpose
.. _lax-control-flow:
Control flow operators
----------------------

View File

@ -2,7 +2,7 @@ Internal APIs
=============
core
-----
----
.. currentmodule:: jax.core
.. automodule:: jax.core

View File

@ -91,6 +91,7 @@ from .version import __version__
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from . import errors
from . import image
from . import lax
from . import nn

292
jax/_src/errors.py Normal file
View File

@ -0,0 +1,292 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import core
class JAXTypeError(TypeError):
"""Base class for JAX-specific TypeErrors"""
def __init__(self, message: str):
error_page = 'https://jax.readthedocs.io/en/latest/errors.html'
module_name = self.__class__.__module__
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
super().__init__(error_msg)
class ConcretizationTypeError(JAXTypeError):
"""
This error occurs when a JAX Tracer object is used in a context where a concrete value
is required. In some situations, it can be easily fixed by marking problematic values
as static; in others, it may indicate that your program is doing operations that are
not directly supported by JAX's JIT compilation model.
Traced value where static value is expected
One common cause of this error is using a traced value where a static value is required.
For example:
>>> from jax import jit, partial
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
... return x.min(axis)
>>> func(jnp.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
axis argument to jnp.min().
This can often be fixed by marking the problematic value as static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
... return x.min(axis)
>>> func(jnp.arange(4), 0)
DeviceArray(0, dtype=int32)
Traced value used in control flow
Another case where this often arises is when a traced value is used in Python control flow.
For example::
>>> @jit
... def func(x, y):
... return x if x.sum() < y.sum() else y
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The problem arose with the `bool` function.
In this case, marking the problematic traced quantity as static is not an option, because it
is derived from traced inputs. But you can make progress by re-expressing this if statement
in terms of :func:`jax.numpy.where`::
>>> @jit
... def func(x, y):
... return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
DeviceArray([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see :ref:`lax-control-flow`.
Shape depends on Traced Value
Such an error may also arise when a shape in your JIT-compiled computation depends
on the values within a traced quantity. For example::
>>> @jit
... def func(x):
... return jnp.where(x < 0)
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
This is an example of an operation that is incompatible with JAX's JIT compilation model,
which requires array sizes to be known at compile-time. Here the size of the returned
array depends on the contents of `x`, and such code cannot be JIT compiled.
In many cases it is possible to work around this by modifying the logic used in the function;
for example here is code with a similar issue::
>>> @jit
... def func(x):
... indices = jnp.where(x > 1)
... return x[indices].sum()
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
And here is how you might express the same operation in a way that avoids creation of a
dynamically-sized index array::
>>> @jit
... def func(x):
... return jnp.where(x > 1, x, 0).sum()
>>> func(jnp.arange(4))
DeviceArray(5, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer", context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
f"{tracer}\n{context}\n{tracer._origin_msg()}\n")
class TracerArrayConversionError(JAXTypeError):
"""
This error occurs when a program attempts to convert a JAX Tracer object into a
standard NumPy array. It typically occurs in one of a few situations.
Using `numpy` rather than `jax.numpy` functions
This error can occur when a JAX Tracer object is passed to a raw numpy function,
or a method on a numpy.ndarray object. For example::
>>> from jax import jit, partial
>>> import numpy as np
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
... return np.sin(x)
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
In this case, check that you are using `jax.numpy` methods rather than `numpy` methods::
>>> @jit
... def func(x):
... return jnp.sin(x)
>>> func(jnp.arange(4))
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that the array being
indexed `x` is a raw numpy.ndarray while the indices `idx` are traced. For example::
>>> x = np.arange(10)
>>> @jit
... def func(i):
... return x[i]
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
Depending on the context, you may fix this by converting the numpy array into a JAX array::
>>> @jit
... def func(i):
... return jnp.asarray(x)[i]
>>> func(0)
DeviceArray(0, dtype=int32)
or by declaring the index as a static argument::
>>> @partial(jit, static_argnums=(0,))
... def func(i):
... return x[i]
>>> func(0)
DeviceArray(0, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {tracer}")
class TracerIntegerConversionError(JAXTypeError):
"""
This error can occur when a JAX Tracer object is used in a context where a Python integer
is expected. It typically occurs in a few situations.
Passing a tracer in place of an integer
This error can occur if you attempt to pass a tracer to a function that requires an integer
argument; for example::
>>> from jax import jit, partial
>>> import numpy as np
>>> @jit
... def func(x, axis):
... return np.split(x, 2, axis)
>>> func(np.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
When this happens, the solution is often to mark the problematic argument as static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
... return np.split(x, 2, axis)
>>> func(np.arange(10), 0)
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
An alternative is to apply the transformation to a closure that encapsulates the arguments
to be protected, either manually as below or by using :func:`functools.partial`::
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
**Note a new closure is created at every invocation, which defeats the compilation
caching mechanism, which is why static_argnums is preferred.**
Indexing a list with a Tracer
This error can occur if you attempt to index a Python list with a traced quantity.
For example::
>>> import jax.numpy as jnp
>>> from jax import jit, partial
>>> L = [1, 2, 3]
>>> @jit
... def func(i):
... return L[i]
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
Depending on the context, you can generally fix this either by converting the list
to a JAX array::
>>> @jit
... def func(i):
... return jnp.array(L)[i]
>>> func(0)
DeviceArray(1, dtype=int32)
or by declaring the index as a static argument::
>>> @partial(jit, static_argnums=0)
... def func(i):
... return L[i]
>>> func(0)
DeviceArray(1, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
super().__init__(
f"The __index__() method was called on the JAX Tracer object {tracer}")

View File

@ -30,6 +30,8 @@ import numpy as np
from . import dtypes
from .config import FLAGS, config
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError)
from . import linear_util as lu
from jax._src import source_info_util
@ -483,28 +485,10 @@ class Tracer:
__slots__ = ['_trace', '__weakref__', '_line_info']
def __array__(self, *args, **kw):
msg = ("The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {self}.\n\n"
"This error can occur when a JAX Tracer object is passed to a raw "
"numpy function, or a method on a numpy.ndarray object. You might "
"want to check that you are using `jnp` together with "
"`import jax.numpy as jnp` rather than using `np` via "
"`import numpy as np`. If this error arises on a line that involves "
"array indexing, like `x[idx]`, it may be that the array being "
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
"JAX Tracer instance; in that case, you can instead write "
"`jnp.asarray(x)[idx]`.")
raise Exception(msg)
raise TracerArrayConversionError(self)
def __index__(self):
msg = (f"The __index__ method was called on the JAX Tracer object {self}.\n\n"
"This error can occur when a JAX Tracer object is used in a context where "
"a Python integer is expected, such as an argument to the range() function, "
"or in index to a Python list. In the latter case, this can often be fixed "
"by converting the indexed object to a JAX array, for example by changing "
"`obj[idx]` to `jnp.asarray(obj)[idx]`."
)
raise TypeError(msg)
raise TracerIntegerConversionError(self)
def __init__(self, trace: Trace):
self._trace = trace
@ -956,17 +940,6 @@ unitvar = UnitVar()
pytype_aval_mappings[Unit] = lambda _: abstract_unit
class ConcretizationTypeError(TypeError): pass
def raise_concretization_error(val: Tracer, context=""):
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
+ context + "\n\n"
+ val._origin_msg() + "\n\n"
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
f"Encountered tracer value: {val}")
raise ConcretizationTypeError(msg)
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
fname_context = f"The problem arose with the `{fname}` function. "
@ -975,10 +948,9 @@ def concretization_function_error(fun, suggest_astype=False):
f"try using `x.astype({fun.__name__})` "
f"or `jnp.array(x, {fun.__name__})` instead.")
def error(self, arg):
raise_concretization_error(arg, fname_context)
raise ConcretizationTypeError(arg, fname_context)
return error
def concrete_or_error(force: Any, val: Any, context=""):
"""Like force(val), but gives the context in the error message."""
if force is None:
@ -987,7 +959,7 @@ def concrete_or_error(force: Any, val: Any, context=""):
if isinstance(val.aval, ConcreteArray):
return force(val.aval.val)
else:
raise_concretization_error(val, context)
raise ConcretizationTypeError(val, context)
else:
return force(val)

19
jax/errors.py Normal file
View File

@ -0,0 +1,19 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from ._src.errors import (JAXTypeError,
ConcretizationTypeError,
TracerArrayConversionError,
TracerIntegerConversionError)

View File

@ -922,10 +922,7 @@ class DynamicJaxprTracer(core.Tracer):
"depends on the value of the arguments to "
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
"and the computation of these values is being staged out "
"(that is, delayed rather than executed eagerly).\n\n"
"You can use transformation parameters such as `static_argnums` "
"for `jit` to avoid tracing particular arguments of transformed "
"functions, though at the cost of more recompiles.")
"(that is, delayed rather than executed eagerly).")
elif progenitor_eqns:
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"

View File

@ -658,7 +658,7 @@ class APITest(jtu.JaxTestCase):
assert jit(f, static_argnums=(0,))(0) == L[0]
self.assertRaisesRegex(
TypeError,
"The __index__ method was called on the JAX Tracer object.*",
r"The __index__\(\) method was called on the JAX Tracer object.*",
lambda: jit(f)(0))
def test_range_err(self):
@ -670,7 +670,7 @@ class APITest(jtu.JaxTestCase):
assert jit(f, static_argnums=(1,))(0, 5) == 10
self.assertRaisesRegex(
TypeError,
"The __index__ method was called on the JAX Tracer object.*",
r"The __index__\(\) method was called on the JAX Tracer object.*",
lambda: jit(f)(0, 5))
def test_cast_int(self):
@ -685,7 +685,7 @@ class APITest(jtu.JaxTestCase):
f = lambda x: castfun(x)
self.assertRaisesRegex(
TypeError,
"The __index__ method was called on the JAX Tracer object.*", lambda: jit(f)(0))
r"The __index__\(\) method was called on the JAX Tracer object.*", lambda: jit(f)(0))
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')

View File

@ -3228,7 +3228,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
x = x * x
return x
self.assertRaises(TypeError, lambda: f(3., 3))
self.assertRaises(jax.errors.TracerIntegerConversionError, lambda: f(3., 3))
@api.jit
def g(x):
@ -3237,7 +3237,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
else:
return x + 2
self.assertRaises(TypeError, lambda: g(3.))
self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.))
def testTracingPrimitiveWithNoTranslationErrorMessage(self):
# TODO(mattjj): update this for jax3