Merge pull request #6612 from google:tracer-errors

PiperOrigin-RevId: 372211269
This commit is contained in:
jax authors 2021-05-05 14:45:57 -07:00
commit 3c6a41eb9c
11 changed files with 336 additions and 159 deletions

View File

@ -2400,9 +2400,10 @@ def eval_shape(fun: Callable, *args, **kwargs):
"""
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
*map(shaped_abstractify, args_flat),
transform_name="eval_shape")
debug_info=debug_info)
out = [ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
return tree_unflatten(out_tree(), out)

View File

@ -23,7 +23,7 @@ class _JAXErrorMixin:
error_page = self._error_page
module_name = self._module_name
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
error_msg = f'{message}See {error_page}#{module_name}.{class_name}'
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore
@ -38,14 +38,15 @@ class JAXIndexError(_JAXErrorMixin, IndexError):
class ConcretizationTypeError(JAXTypeError):
"""
This error occurs when a JAX Tracer object is used in a context where a concrete value
is required. In some situations, it can be easily fixed by marking problematic values
as static; in others, it may indicate that your program is doing operations that are
not directly supported by JAX's JIT compilation model.
This error occurs when a JAX Tracer object is used in a context where a
concrete value is required. In some situations, it can be easily fixed by
marking problematic values as static; in others, it may indicate that your
program is doing operations that are not directly supported by JAX's JIT
compilation model.
Traced value where static value is expected
One common cause of this error is using a traced value where a static value is required.
For example:
One common cause of this error is using a traced value where a static value
is required. For example:
>>> from jax import jit, partial
>>> import jax.numpy as jnp
@ -56,10 +57,10 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
axis argument to jnp.min().
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().
This can often be fixed by marking the problematic value as static::
This can often be fixed by marking the problematic argument as static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
@ -69,8 +70,8 @@ class ConcretizationTypeError(JAXTypeError):
DeviceArray(0, dtype=int32)
Traced value used in control flow
Another case where this often arises is when a traced value is used in Python control flow.
For example::
Another case where this often arises is when a traced value is used in
Python control flow. For example::
>>> @jit
... def func(x, y):
@ -79,12 +80,12 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The problem arose with the `bool` function.
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: [...]
In this case, marking the problematic traced quantity as static is not an option, because it
is derived from traced inputs. But you can make progress by re-expressing this if statement
in terms of :func:`jax.numpy.where`::
We could mark both inputs ``x`` and ``y`` as static, but that would defeat
the purpose of using :func:`jax.jit` here. Another option is to re-express
the if statement in terms of :func:`jax.numpy.where`::
>>> @jit
... def func(x, y):
@ -93,11 +94,12 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.ones(4), jnp.zeros(4))
DeviceArray([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see :ref:`lax-control-flow`.
For more complicated control flow including loops, see
:ref:`lax-control-flow`.
Shape depends on Traced Value
Such an error may also arise when a shape in your JIT-compiled computation depends
on the values within a traced quantity. For example::
Such an error may also arise when a shape in your JIT-compiled computation
depends on the values within a traced quantity. For example::
>>> @jit
... def func(x):
@ -109,12 +111,13 @@ class ConcretizationTypeError(JAXTypeError):
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
This is an example of an operation that is incompatible with JAX's JIT compilation model,
which requires array sizes to be known at compile-time. Here the size of the returned
array depends on the contents of `x`, and such code cannot be JIT compiled.
This is an example of an operation that is incompatible with JAX's JIT
compilation model, which requires array sizes to be known at compile-time.
Here the size of the returned array depends on the contents of `x`, and such
code cannot be JIT compiled.
In many cases it is possible to work around this by modifying the logic used in the function;
for example here is code with a similar issue::
In many cases it is possible to work around this by modifying the logic used
in the function; for example here is code with a similar issue::
>>> @jit
... def func(x):
@ -124,11 +127,11 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.
And here is how you might express the same operation in a way that avoids creation of a
dynamically-sized index array::
And here is how you might express the same operation in a way that avoids
creation of a dynamically-sized index array::
>>> @jit
... def func(x):
@ -137,29 +140,31 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4))
DeviceArray(5, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer", context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
f"{tracer}\n{context}\n{tracer._origin_msg()}\n")
f"{tracer}\n{context}{tracer._origin_msg()}\n")
class NonConcreteBooleanIndexError(JAXIndexError):
"""
This error occurs when a program attempts to use non-concrete boolean indices
in a traced indexing operation. Under JIT compilation, JAX arrays must have static
shapes (i.e. shapes that are known at compile-time) and so boolean masks must be
used carefully. Some logic implemented via boolean masking is simply not possible
under JAX's JIT compilation model; in other cases, the logic can be re-expressed in
a JIT-compatible way, often using the three-argument version of :func:`~jax.numpy.where`.
in a traced indexing operation. Under JIT compilation, JAX arrays must have
static shapes (i.e. shapes that are known at compile-time) and so boolean
masks must be used carefully. Some logic implemented via boolean masking is
simply not possible in a :func:`jax.jit` function; in other cases, the logic
can be re-expressed in a JIT-compatible way, often using the three-argument
version of :func:`~jax.numpy.where`.
Following are a few examples of when this error might arise.
Constructing arrays via boolean masking
This most commonly arises when attempting to create an array via a boolean mask
within a JIT context. For example::
This most commonly arises when attempting to create an array via a boolean
mask within a JIT context. For example::
>>> import jax
>>> import jax.numpy as jnp
@ -173,14 +178,16 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
This function is attempting to return only the positive values in the input array; the size of
this returned array cannot be determined at compile-time unless `x` is marked as static, and so
operations like this cannot be performed under JIT compilation.
This function is attempting to return only the positive values in the input
array; the size of this returned array cannot be determined at compile-time
unless `x` is marked as static, and so operations like this cannot be
performed under JIT compilation.
Reexpressible Boolean Logic
Although creating dynamically sized arrays is not supported directly, in many cases it is
possible to re-express the logic of the computation in terms of a JIT-compatible operation.
For example, here is another function that fails under JIT for the same reason::
Although creating dynamically sized arrays is not supported directly, in
many cases it is possible to re-express the logic of the computation in
terms of a JIT-compatible operation. For example, here is another function
that fails under JIT for the same reason::
>>> @jax.jit
... def sum_of_positive(x):
@ -191,9 +198,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
In this case, however, the problematic array is only an intermediate value, and we can
instead express the same logic in terms of the JIT-compatible three-argument version of
:func:`jax.numpy.where`::
In this case, however, the problematic array is only an intermediate value,
and we can instead express the same logic in terms of the JIT-compatible
three-argument version of :func:`jax.numpy.where`::
>>> @jax.jit
... def sum_of_positive(x):
@ -202,12 +209,13 @@ class NonConcreteBooleanIndexError(JAXIndexError):
>>> sum_of_positive(jnp.arange(-5, 5))
DeviceArray(10, dtype=int32)
This pattern of replacing boolean masking with three-argument :func:`~jax.numpy.where` is a
common solution to this sort of problem.
This pattern of replacing boolean masking with three-argument
:func:`~jax.numpy.where` is a common solution to this sort of problem.
Boolean indices in :mod:`jax.ops`
The other situation where this error often arises is when using boolean indices within functions
in :mod:`jax.ops`, such as :func:`jax.ops.index_update`. Here is a simple example::
The other situation where this error often arises is when using boolean
indices within functions in :mod:`jax.ops`, such as
:func:`jax.ops.index_update`. Here is a simple example::
>>> @jax.jit
... def manual_clip(x):
@ -218,8 +226,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
This function is attempting to set values smaller than zero to a scalar fill value. As above,
this can be addressed by re-expressing the logic in terms of :func:`~jax.numpy.where`::
This function is attempting to set values smaller than zero to a scalar fill
value. As above, this can be addressed by re-expressing the logic in terms
of :func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_clip(x):
@ -228,8 +237,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
>>> manual_clip(jnp.arange(-2, 2))
DeviceArray([0, 0, 0, 1], dtype=int32)
These operations also commonly are written in terms of the :ref:`syntactic-sugar-for-ops`;
for example, this is syntactic sugar for :func:`~jax.ops.index_mul`, and fails under JIT::
These operations also commonly are written in terms of the
:ref:`syntactic-sugar-for-ops`; for example, this is syntactic sugar for
:func:`~jax.ops.index_mul`, and fails under JIT::
>>> @jax.jit
... def manual_abs(x):
@ -240,7 +250,8 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
As above, the solution is to re-express this in terms of :func:`~jax.numpy.where`::
As above, the solution is to re-express this in terms of
:func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_abs(x):
@ -256,12 +267,12 @@ class NonConcreteBooleanIndexError(JAXIndexError):
class TracerArrayConversionError(JAXTypeError):
"""
This error occurs when a program attempts to convert a JAX Tracer object into a
standard NumPy array. It typically occurs in one of a few situations.
This error occurs when a program attempts to convert a JAX Tracer object into
a standard NumPy array. It typically occurs in one of a few situations.
Using `numpy` rather than `jax.numpy` functions
This error can occur when a JAX Tracer object is passed to a raw numpy function,
or a method on a numpy.ndarray object. For example::
This error can occur when a JAX Tracer object is passed to a raw numpy
function, or a method on a numpy.ndarray object. For example::
>>> from jax import jit, partial
>>> import numpy as np
@ -274,9 +285,11 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on the JAX Tracer object
In this case, check that you are using `jax.numpy` methods rather than `numpy` methods::
In this case, check that you are using `jax.numpy` methods rather than
`numpy` methods::
>>> @jit
... def func(x):
@ -286,8 +299,9 @@ class TracerArrayConversionError(JAXTypeError):
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that the array being
indexed `x` is a raw numpy.ndarray while the indices `idx` are traced. For example::
If this error arises on a line that involves array indexing, it may be that
the array being indexed `x` is a raw numpy.ndarray while the indices `idx`
are traced. For example::
>>> x = np.arange(10)
@ -298,9 +312,11 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on the JAX Tracer object
Depending on the context, you may fix this by converting the numpy array into a JAX array::
Depending on the context, you may fix this by converting the numpy array
into a JAX array::
>>> @jit
... def func(i):
@ -318,24 +334,24 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(0)
DeviceArray(0, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
# TODO(mattjj, jakevdp): use tracer._origin_msg() here
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {tracer}")
f"the JAX Tracer object {tracer}{tracer._origin_msg()}")
class TracerIntegerConversionError(JAXTypeError):
"""
This error can occur when a JAX Tracer object is used in a context where a Python integer
is expected. It typically occurs in a few situations.
This error can occur when a JAX Tracer object is used in a context where a
Python integer is expected. It typically occurs in a few situations.
Passing a tracer in place of an integer
This error can occur if you attempt to pass a tracer to a function that requires an integer
argument; for example::
This error can occur if you attempt to pass a tracer to a function that
requires an integer argument; for example::
>>> from jax import jit, partial
>>> import numpy as np
@ -347,9 +363,11 @@ class TracerIntegerConversionError(JAXTypeError):
>>> func(np.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
TracerIntegerConversionError: The __index__() method was called on the JAX
Tracer object
When this happens, the solution is often to mark the problematic argument as static::
When this happens, the solution is often to mark the problematic argument as
static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
@ -359,17 +377,19 @@ class TracerIntegerConversionError(JAXTypeError):
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
An alternative is to apply the transformation to a closure that encapsulates the arguments
to be protected, either manually as below or by using :func:`functools.partial`::
An alternative is to apply the transformation to a closure that encapsulates
the arguments to be protected, either manually as below or by using
:func:`functools.partial`::
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
**Note a new closure is created at every invocation, which defeats the compilation
caching mechanism, which is why static_argnums is preferred.**
**Note a new closure is created at every invocation, which defeats the
compilation caching mechanism, which is why static_argnums is preferred.**
Indexing a list with a Tracer
This error can occur if you attempt to index a Python list with a traced quantity.
This error can occur if you attempt to index a Python list with a traced
quantity.
For example::
>>> import jax.numpy as jnp
@ -386,8 +406,8 @@ class TracerIntegerConversionError(JAXTypeError):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
Depending on the context, you can generally fix this either by converting the list
to a JAX array::
Depending on the context, you can generally fix this either by converting
the list to a JAX array::
>>> @jit
... def func(i):
@ -405,8 +425,9 @@ class TracerIntegerConversionError(JAXTypeError):
>>> func(0)
DeviceArray(1, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
super().__init__(

View File

@ -64,23 +64,23 @@ Array = Any
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
primitive_name: Optional[str] = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals,
transform_name=transform_name)
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals,
transform_name)
primitive_name: Optional[str] = None):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
fun, in_tree, in_avals, primitive_name)
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
return closed_jaxpr, consts, out_tree
@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
in_tree, in_avals):
def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for
@ -88,8 +88,9 @@ def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_consts, all_out_trees = unzip3(
_initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs)
jaxprs, all_consts, all_out_trees = \
unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
for fun in funs)
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
@ -195,13 +196,17 @@ def fori_loop(lower, upper, body_fun, init_val):
# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
if (isinstance(core.get_aval(lower), ConcreteArray) and
isinstance(core.get_aval(upper), ConcreteArray)):
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
else:
use_scan = True
else:
use_scan = True
use_scan = False
if use_scan:
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
@ -273,8 +278,10 @@ def while_loop(cond_fun: Callable[[T], bool],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals, "while_loop")
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, "while_loop")
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
@ -609,7 +616,7 @@ def switch(index, branches: Sequence[Callable], operand):
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals)
branches, ops_tree, ops_avals, primitive_name='switch')
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
@ -689,7 +696,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals)
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
@ -1264,7 +1271,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
jaxpr, consts, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
@ -2182,11 +2190,13 @@ def custom_linear_solve(
return f
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec"), in_args_tree, b_avals)
_shape_checked(matvec, "matvec"), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("matvec", "b", out_tree, tree)
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals)
_shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("solve", "b", out_tree, tree)
if transpose_solve is None:
@ -2200,12 +2210,12 @@ def custom_linear_solve(
else:
vecmat = _transpose_one_output(matvec, b)
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
vecmat, in_args_tree, b_avals)
vecmat, in_args_tree, b_avals, 'custom_linear_solve')
assert out_tree == tree
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
in_args_tree, b_avals)
in_args_tree, b_avals, 'custom_linear_solve')
_check_tree("transpose_solve", "b", out_tree, tree)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]

View File

@ -621,12 +621,13 @@ def power(x1, x2):
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
try:
x2 = core.concrete_or_error(operator.index, x2)
except (core.ConcretizationTypeError, TypeError):
pass
else:
return lax.integer_pow(x1, x2)
if isinstance(core.get_aval(x2), ConcreteArray):
try:
x2 = operator.index(x2)
except TypeError:
pass
else:
return lax.integer_pow(x1, x2)
x1, x2 = _promote_args("power", x1, x2)
dtype = _dtype(x1)

View File

@ -13,7 +13,8 @@
# limitations under the License.
import operator
from typing import Any, Dict, Iterable, Tuple, Union
from functools import partial
from typing import Any, Dict, Iterable, Tuple, Union, Optional
import numpy as np
@ -85,6 +86,26 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
ans = fun(*args)
return tree_unflatten(out_tree, ans)
PyTreeDef = Any
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
# with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
# core.eval_jaxpr encounters a call primitive (though at that point we're just
# round-tripping jaxprs and the user errors in question are impossible).
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
try:
(in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
for f, args in fn.transforms if f in flat_xforms)
except ValueError:
return None
else:
return in_tree, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)

View File

@ -451,12 +451,12 @@ def escaped_tracer_error(tracer, detail=None):
'frames (most recent last) excluding JAX-internal frames were:\n'
f'{source_info_util.summarize(line_info, num_frames=num_frames)}')
try:
fun_source_info = tracer._trace.main.source_info
dbg = tracer._trace.main.debug_info
except AttributeError:
pass
else:
msg += ('\nThe function being traced when the tracer leaked was '
f'{fun_source_info}.')
f'{dbg.func_src_info} traced for {dbg.traced_for}.')
msg += ('\nTo catch the leak earlier, try setting the environment variable '
'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
'manager.')

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import itertools as it
from collections import namedtuple
import contextlib
@ -26,6 +27,8 @@ from .. import core
from .._src import dtypes
from .. import linear_util as lu
from ..ad_util import Zero
from ..api_util import flattened_fun_in_tree
from .._src.tree_util import tree_unflatten, tree_leaves
from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
split_list, cache, as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
@ -403,8 +406,9 @@ call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, transform_name="", **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals, transform_name)
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@ -902,24 +906,25 @@ class DynamicJaxprTracer(core.Tracer):
def _origin_msg(self):
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
dbg = self._trace.main.debug_info
if invar_pos:
origin = (f"While tracing the function {self._trace.main.source_info}, "
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this concrete value was not available in Python because it "
"depends on the value of the arguments to "
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
"and the computation of these values is being staged out "
"(that is, delayed rather than executed eagerly).")
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
elif progenitor_eqns:
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
origin = (f"While tracing the function {self._trace.main.source_info}, "
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msts))
else:
origin = ("The error occured while tracing the function "
f"{self._trace.main.source_info}.")
return origin
origin = (f"The error occured while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}.")
return "\n" + origin
def _assert_live(self) -> None:
if not self._trace.main.jaxpr_stack: # type: ignore
@ -1176,11 +1181,71 @@ def _memoize(thunk):
return memoized
class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
PyTreeDef = Any
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
except (TypeError, ValueError):
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
debug_info: Optional[DebugInfo] = None):
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main, fun
@ -1209,9 +1274,9 @@ def extend_jaxpr_stack(main, frame):
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
@ -1225,16 +1290,3 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
def fun_sourceinfo(fun, transform_name: str = ""):
if isinstance(fun, functools.partial):
fun = fun.func
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
if transform_name:
line_info += f', transformed by {transform_name}.'
return line_info
except AttributeError:
return "<unknown>"

View File

@ -696,7 +696,8 @@ def parallel_callable(fun: lu.WrappedFun,
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_axes = out_axes_thunk()

View File

@ -653,8 +653,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, _ = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
abstract_args, arg_devices = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)

View File

@ -2224,17 +2224,17 @@ class APITest(jtu.JaxTestCase):
def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by jit"):
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by pmap"):
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by eval_shape"):
"for eval_shape"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1
@ -2278,7 +2278,19 @@ class APITest(jtu.JaxTestCase):
f() # doesn't crash
def test_concrete_error_because_arg(self):
def test_concrete_error_because_arg_unary(self):
@jax.jit
def f(x):
if x > 0:
return x
else:
return 0
msg = r"on the value of the argument 'x'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
def test_concrete_error_because_arg_binary(self):
@jax.jit
def f(x, y):
if x > y:
@ -2286,10 +2298,67 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"at flattened positions \[0, 1\]"
msg = r"on the values of the arguments 'x' and 'y'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
def test_concrete_error_because_arg_ternary(self):
@jax.jit
def f(x, y, z):
if x > z:
return x
else:
return y
msg = r"on the values of the arguments 'x' and 'z'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, z=3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, y=2, z=3)
def test_concrete_error_because_arg_varargs(self):
@jax.jit
def f(*args):
x, y, z = args
if x > z:
return x
else:
return y
msg = r"on the values of the argument 'args'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
def test_concrete_error_because_arg_kwargs(self):
@jax.jit
def f(**kwargs):
x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
if x > z:
return x
else:
return y
msg = r"on the values of the argument 'kwargs'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
def test_concrete_error_because_arg_pytree(self):
@jax.jit
def f(xy, z):
x, y = xy
if x > 0:
return x
else:
return y
msg = r"on the value of the argument 'xy'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
def test_concrete_error_because_const(self):
@jax.jit
def f():

View File

@ -2617,7 +2617,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def test_unexpected_tracer_error(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by while_loop"):
"for while_loop"):
lst = []
def side_effecting_body(val):
lst.append(val)
@ -2626,7 +2626,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lst[0] += 1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by scan"):
"for scan"):
lst = []
def side_effecting_scan(carry, val):
lst.append(val)