Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
This commit is contained in commit 3c6a41eb9c.
@@ -2400,9 +2400,10 @@ def eval_shape(fun: Callable, *args, **kwargs):
"""
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
*map(shaped_abstractify, args_flat),
transform_name="eval_shape")
debug_info=debug_info)
out = [ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
return tree_unflatten(out_tree(), out)
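For context, a minimal usage sketch (not part of the diff) of the eval_shape API whose internals change above; the function f and the shapes below are illustrative assumptions:

    import jax
    import jax.numpy as jnp

    def f(x, w):
      return jnp.dot(x, w)

    # eval_shape goes through pe.abstract_eval_fun, so no FLOPs are performed;
    # only shapes and dtypes are propagated.
    x = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    w = jax.ShapeDtypeStruct((128, 256), jnp.float32)
    out = jax.eval_shape(f, x, w)
    print(out.shape, out.dtype)  # (8, 256) float32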
@@ -23,7 +23,7 @@ class _JAXErrorMixin:
error_page = self._error_page
module_name = self._module_name
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
error_msg = f'{message}See {error_page}#{module_name}.{class_name}'
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore

@@ -38,14 +38,15 @@ class JAXIndexError(_JAXErrorMixin, IndexError):
class ConcretizationTypeError(JAXTypeError):
"""
This error occurs when a JAX Tracer object is used in a context where a concrete value
is required. In some situations, it can be easily fixed by marking problematic values
as static; in others, it may indicate that your program is doing operations that are
not directly supported by JAX's JIT compilation model.
This error occurs when a JAX Tracer object is used in a context where a
concrete value is required. In some situations, it can be easily fixed by
marking problematic values as static; in others, it may indicate that your
program is doing operations that are not directly supported by JAX's JIT
compilation model.
Traced value where static value is expected
One common cause of this error is using a traced value where a static value is required.
For example:
One common cause of this error is using a traced value where a static value
is required. For example:
>>> from jax import jit, partial
>>> import jax.numpy as jnp

@@ -56,10 +57,10 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
axis argument to jnp.min().
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().
This can often be fixed by marking the problematic value as static::
This can often be fixed by marking the problematic argument as static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):

@@ -69,8 +70,8 @@ class ConcretizationTypeError(JAXTypeError):
DeviceArray(0, dtype=int32)
Traced value used in control flow
Another case where this often arises is when a traced value is used in Python control flow.
For example::
Another case where this often arises is when a traced value is used in
Python control flow. For example::
>>> @jit
... def func(x, y):

@@ -79,12 +80,12 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The problem arose with the `bool` function.
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: [...]
In this case, marking the problematic traced quantity as static is not an option, because it
is derived from traced inputs. But you can make progress by re-expressing this if statement
in terms of :func:`jax.numpy.where`::
We could mark both inputs ``x`` and ``y`` as static, but that would defeat
the purpose of using :func:`jax.jit` here. Another option is to re-express
the if statement in terms of :func:`jax.numpy.where`::
>>> @jit
... def func(x, y):

@@ -93,11 +94,12 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.ones(4), jnp.zeros(4))
DeviceArray([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see :ref:`lax-control-flow`.
For more complicated control flow including loops, see
:ref:`lax-control-flow`.
Shape depends on Traced Value
Such an error may also arise when a shape in your JIT-compiled computation depends
on the values within a traced quantity. For example::
Such an error may also arise when a shape in your JIT-compiled computation
depends on the values within a traced quantity. For example::
>>> @jit
... def func(x):
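To illustrate the rewrite the docstring above describes, here is a small sketch (not taken from the diff) of replacing Python control flow on a traced value with an elementwise jnp.where select; the function and inputs are made up:

    from jax import jit
    import jax.numpy as jnp

    @jit
    def func(x, y):
      # `if (x > y).any():` would call bool() on a tracer and raise
      # ConcretizationTypeError; an elementwise select stays traceable.
      return jnp.where(x > y, x, y)

    func(jnp.ones(4), jnp.zeros(4))  # DeviceArray([1., 1., 1., 1.], dtype=float32)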
@@ -109,12 +111,13 @@ class ConcretizationTypeError(JAXTypeError):
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
This is an example of an operation that is incompatible with JAX's JIT compilation model,
which requires array sizes to be known at compile-time. Here the size of the returned
array depends on the contents of `x`, and such code cannot be JIT compiled.
This is an example of an operation that is incompatible with JAX's JIT
compilation model, which requires array sizes to be known at compile-time.
Here the size of the returned array depends on the contents of `x`, and such
code cannot be JIT compiled.
In many cases it is possible to work around this by modifying the logic used in the function;
for example here is code with a similar issue::
In many cases it is possible to work around this by modifying the logic used
in the function; for example here is code with a similar issue::
>>> @jit
... def func(x):

@@ -124,11 +127,11 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.
And here is how you might express the same operation in a way that avoids creation of a
dynamically-sized index array::
And here is how you might express the same operation in a way that avoids
creation of a dynamically-sized index array::
>>> @jit
... def func(x):

@@ -137,29 +140,31 @@ class ConcretizationTypeError(JAXTypeError):
>>> func(jnp.arange(4))
DeviceArray(5, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer", context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
f"{tracer}\n{context}\n{tracer._origin_msg()}\n")
f"{tracer}\n{context}{tracer._origin_msg()}\n")
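As a companion to the jnp.nonzero discussion above, a hedged sketch (not part of the diff) of keeping shapes static by reducing over a boolean mask instead of materializing a dynamically-sized index array; the function name and inputs are assumptions:

    from jax import jit
    import jax.numpy as jnp

    @jit
    def count_and_sum_positive(x):
      mask = x > 0
      # Reductions over a mask have static output shapes, unlike x[jnp.nonzero(x)].
      return mask.sum(), jnp.where(mask, x, 0).sum()

    count_and_sum_positive(jnp.arange(-3, 4))
    # (DeviceArray(3, dtype=int32), DeviceArray(6, dtype=int32))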
class NonConcreteBooleanIndexError(JAXIndexError):
"""
This error occurs when a program attempts to use non-concrete boolean indices
in a traced indexing operation. Under JIT compilation, JAX arrays must have static
shapes (i.e. shapes that are known at compile-time) and so boolean masks must be
used carefully. Some logic implemented via boolean masking is simply not possible
under JAX's JIT compilation model; in other cases, the logic can be re-expressed in
a JIT-compatible way, often using the three-argument version of :func:`~jax.numpy.where`.
in a traced indexing operation. Under JIT compilation, JAX arrays must have
static shapes (i.e. shapes that are known at compile-time) and so boolean
masks must be used carefully. Some logic implemented via boolean masking is
simply not possible in a :func:`jax.jit` function; in other cases, the logic
can be re-expressed in a JIT-compatible way, often using the three-argument
version of :func:`~jax.numpy.where`.
Following are a few examples of when this error might arise.
Constructing arrays via boolean masking
This most commonly arises when attempting to create an array via a boolean mask
within a JIT context. For example::
This most commonly arises when attempting to create an array via a boolean
mask within a JIT context. For example::
>>> import jax
>>> import jax.numpy as jnp

@@ -173,14 +178,16 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
This function is attempting to return only the positive values in the input array; the size of
this returned array cannot be determined at compile-time unless `x` is marked as static, and so
operations like this cannot be performed under JIT compilation.
This function is attempting to return only the positive values in the input
array; the size of this returned array cannot be determined at compile-time
unless `x` is marked as static, and so operations like this cannot be
performed under JIT compilation.
Reexpressible Boolean Logic
Although creating dynamically sized arrays is not supported directly, in many cases it is
possible to re-express the logic of the computation in terms of a JIT-compatible operation.
For example, here is another function that fails under JIT for the same reason::
Although creating dynamically sized arrays is not supported directly, in
many cases it is possible to re-express the logic of the computation in
terms of a JIT-compatible operation. For example, here is another function
that fails under JIT for the same reason::
>>> @jax.jit
... def sum_of_positive(x):

@@ -191,9 +198,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
In this case, however, the problematic array is only an intermediate value, and we can
instead express the same logic in terms of the JIT-compatible three-argument version of
:func:`jax.numpy.where`::
In this case, however, the problematic array is only an intermediate value,
and we can instead express the same logic in terms of the JIT-compatible
three-argument version of :func:`jax.numpy.where`::
>>> @jax.jit
... def sum_of_positive(x):

@@ -202,12 +209,13 @@ class NonConcreteBooleanIndexError(JAXIndexError):
>>> sum_of_positive(jnp.arange(-5, 5))
DeviceArray(10, dtype=int32)
This pattern of replacing boolean masking with three-argument :func:`~jax.numpy.where` is a
common solution to this sort of problem.
This pattern of replacing boolean masking with three-argument
:func:`~jax.numpy.where` is a common solution to this sort of problem.
Boolean indices in :mod:`jax.ops`
The other situation where this error often arises is when using boolean indices within functions
in :mod:`jax.ops`, such as :func:`jax.ops.index_update`. Here is a simple example::
The other situation where this error often arises is when using boolean
indices within functions in :mod:`jax.ops`, such as
:func:`jax.ops.index_update`. Here is a simple example::
>>> @jax.jit
... def manual_clip(x):

@@ -218,8 +226,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
This function is attempting to set values smaller than zero to a scalar fill value. As above,
this can be addressed by re-expressing the logic in terms of :func:`~jax.numpy.where`::
This function is attempting to set values smaller than zero to a scalar fill
value. As above, this can be addressed by re-expressing the logic in terms
of :func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_clip(x):

@@ -228,8 +237,9 @@ class NonConcreteBooleanIndexError(JAXIndexError):
>>> manual_clip(jnp.arange(-2, 2))
DeviceArray([0, 0, 0, 1], dtype=int32)
These operations also commonly are written in terms of the :ref:`syntactic-sugar-for-ops`;
for example, this is syntactic sugar for :func:`~jax.ops.index_mul`, and fails under JIT::
These operations also commonly are written in terms of the
:ref:`syntactic-sugar-for-ops`; for example, this is syntactic sugar for
:func:`~jax.ops.index_mul`, and fails under JIT::
>>> @jax.jit
... def manual_abs(x):

@@ -240,7 +250,8 @@ class NonConcreteBooleanIndexError(JAXIndexError):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
As above, the solution is to re-express this in terms of :func:`~jax.numpy.where`::
As above, the solution is to re-express this in terms of
:func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_abs(x):
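For illustration (not part of the diff), the where-based rewrites the docstring recommends can be written out in full; the expected outputs in the comments follow from the inputs shown:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def manual_clip(x):
      # x[x < 0] = 0 style updates need concrete boolean indices under jit;
      # the three-argument where form is shape-preserving and jit-friendly.
      return jnp.where(x < 0, 0, x)

    @jax.jit
    def manual_abs(x):
      return jnp.where(x < 0, -x, x)

    manual_clip(jnp.arange(-2, 2))  # DeviceArray([0, 0, 0, 1], dtype=int32)
    manual_abs(jnp.arange(-2, 2))   # DeviceArray([2, 1, 0, 1], dtype=int32)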
@@ -256,12 +267,12 @@ class NonConcreteBooleanIndexError(JAXIndexError):
class TracerArrayConversionError(JAXTypeError):
"""
This error occurs when a program attempts to convert a JAX Tracer object into a
standard NumPy array. It typically occurs in one of a few situations.
This error occurs when a program attempts to convert a JAX Tracer object into
a standard NumPy array. It typically occurs in one of a few situations.
Using `numpy` rather than `jax.numpy` functions
This error can occur when a JAX Tracer object is passed to a raw numpy function,
or a method on a numpy.ndarray object. For example::
This error can occur when a JAX Tracer object is passed to a raw numpy
function, or a method on a numpy.ndarray object. For example::
>>> from jax import jit, partial
>>> import numpy as np

@@ -274,9 +285,11 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on the JAX Tracer object
In this case, check that you are using `jax.numpy` methods rather than `numpy` methods::
In this case, check that you are using `jax.numpy` methods rather than
`numpy` methods::
>>> @jit
... def func(x):

@@ -286,8 +299,9 @@ class TracerArrayConversionError(JAXTypeError):
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that the array being
indexed `x` is a raw numpy.ndarray while the indices `idx` are traced. For example::
If this error arises on a line that involves array indexing, it may be that
the array being indexed `x` is a raw numpy.ndarray while the indices `idx`
are traced. For example::
>>> x = np.arange(10)

@@ -298,9 +312,11 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on the JAX Tracer object
Depending on the context, you may fix this by converting the numpy array into a JAX array::
Depending on the context, you may fix this by converting the numpy array
into a JAX array::
>>> @jit
... def func(i):
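A minimal sketch (not from the diff) of the conversion fix the docstring describes; the array name and values are assumptions:

    import numpy as np
    import jax.numpy as jnp
    from jax import jit

    x_np = np.arange(10)
    x_jnp = jnp.asarray(x_np)  # convert once, outside the traced function

    @jit
    def func(i):
      # Indexing the jnp array with a traced integer stages out a gather;
      # indexing the raw numpy array x_np here would call __array__ on the
      # tracer and raise TracerArrayConversionError.
      return x_jnp[i]

    func(0)  # DeviceArray(0, dtype=int32)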
@@ -318,24 +334,24 @@ class TracerArrayConversionError(JAXTypeError):
>>> func(0)
DeviceArray(0, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
# TODO(mattjj, jakevdp): use tracer._origin_msg() here
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {tracer}")
f"the JAX Tracer object {tracer}{tracer._origin_msg()}")

class TracerIntegerConversionError(JAXTypeError):
"""
This error can occur when a JAX Tracer object is used in a context where a Python integer
is expected. It typically occurs in a few situations.
This error can occur when a JAX Tracer object is used in a context where a
Python integer is expected. It typically occurs in a few situations.
Passing a tracer in place of an integer
This error can occur if you attempt to pass a tracer to a function that requires an integer
argument; for example::
This error can occur if you attempt to pass a tracer to a function that
requires an integer argument; for example::
>>> from jax import jit, partial
>>> import numpy as np

@@ -347,9 +363,11 @@ class TracerIntegerConversionError(JAXTypeError):
>>> func(np.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
TracerIntegerConversionError: The __index__() method was called on the JAX
Tracer object
When this happens, the solution is often to mark the problematic argument as static::
When this happens, the solution is often to mark the problematic argument as
static::
>>> @partial(jit, static_argnums=1)
... def func(x, axis):

@@ -359,17 +377,19 @@ class TracerIntegerConversionError(JAXTypeError):
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
An alternative is to apply the transformation to a closure that encapsulates the arguments
to be protected, either manually as below or by using :func:`functools.partial`::
An alternative is to apply the transformation to a closure that encapsulates
the arguments to be protected, either manually as below or by using
:func:`functools.partial`::
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
**Note a new closure is created at every invocation, which defeats the compilation
caching mechanism, which is why static_argnums is preferred.**
**Note a new closure is created at every invocation, which defeats the
compilation caching mechanism, which is why static_argnums is preferred.**
Indexing a list with a Tracer
This error can occur if you attempt to index a Python list with a traced quantity.
This error can occur if you attempt to index a Python list with a traced
quantity.
For example::
>>> import jax.numpy as jnp

@@ -386,8 +406,8 @@ class TracerIntegerConversionError(JAXTypeError):
...
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
Depending on the context, you can generally fix this either by converting the list
to a JAX array::
Depending on the context, you can generally fix this either by converting
the list to a JAX array::
>>> @jit
... def func(i):
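For illustration (not part of the diff), the list-to-array fix mentioned above written out in full; the list contents are assumptions:

    from jax import jit
    import jax.numpy as jnp

    lst = [1, 2, 3]
    arr = jnp.asarray(lst)  # a JAX array can be indexed with a traced integer

    @jit
    def func(i):
      # lst[i] would call __index__ on the tracer and raise
      # TracerIntegerConversionError; arr[i] stages out a gather instead.
      return arr[i]

    func(0)  # DeviceArray(1, dtype=int32)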
@@ -405,8 +425,9 @@ class TracerIntegerConversionError(JAXTypeError):
>>> func(0)
DeviceArray(1, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
super().__init__(
@@ -64,23 +64,23 @@ Array = Any
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
primitive_name: Optional[str] = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals,
transform_name=transform_name)
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals,
transform_name)
primitive_name: Optional[str] = None):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
fun, in_tree, in_avals, primitive_name)
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
return closed_jaxpr, consts, out_tree
@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
in_tree, in_avals):
def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for

@@ -88,8 +88,9 @@ def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_consts, all_out_trees = unzip3(
_initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs)
jaxprs, all_consts, all_out_trees = \
unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
for fun in funs)
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
@@ -195,13 +196,17 @@ def fori_loop(lower, upper, body_fun, init_val):
# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
if (isinstance(core.get_aval(lower), ConcreteArray) and
isinstance(core.get_aval(upper), ConcreteArray)):
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
else:
use_scan = True
else:
use_scan = True
use_scan = False
if use_scan:
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
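A hedged sketch (not part of the diff) of the two code paths the revised fori_loop distinguishes; the loop bodies here are made up:

    import jax
    from jax import lax

    # Concrete bounds: the trip count is known, so fori_loop can lower to scan,
    # which supports efficient reverse-mode differentiation.
    lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)  # 45

    # Traced bounds: get_aval(upper) is not a ConcreteArray, so the new code
    # falls back to while_loop instead of tripping on int() of a tracer.
    @jax.jit
    def loop_to(n):
      return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

    loop_to(10)  # 45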
@@ -273,8 +278,10 @@ def while_loop(cond_fun: Callable[[T], bool],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals, "while_loop")
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, "while_loop")
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))

@@ -609,7 +616,7 @@ def switch(index, branches: Sequence[Callable], operand):
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals)
branches, ops_tree, ops_avals, primitive_name='switch')
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",

@@ -689,7 +696,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals)
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees

@@ -1264,7 +1271,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
jaxpr, consts, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."

@@ -2182,11 +2190,13 @@ def custom_linear_solve(
return f
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec"), in_args_tree, b_avals)
_shape_checked(matvec, "matvec"), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("matvec", "b", out_tree, tree)
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals)
_shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("solve", "b", out_tree, tree)
if transpose_solve is None:

@@ -2200,12 +2210,12 @@ def custom_linear_solve(
else:
vecmat = _transpose_one_output(matvec, b)
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
vecmat, in_args_tree, b_avals)
vecmat, in_args_tree, b_avals, 'custom_linear_solve')
assert out_tree == tree
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
in_args_tree, b_avals)
in_args_tree, b_avals, 'custom_linear_solve')
_check_tree("transpose_solve", "b", out_tree, tree)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
@@ -621,12 +621,13 @@ def power(x1, x2):
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
try:
x2 = core.concrete_or_error(operator.index, x2)
except (core.ConcretizationTypeError, TypeError):
pass
else:
return lax.integer_pow(x1, x2)
if isinstance(core.get_aval(x2), ConcreteArray):
try:
x2 = operator.index(x2)
except TypeError:
pass
else:
return lax.integer_pow(x1, x2)
x1, x2 = _promote_args("power", x1, x2)
dtype = _dtype(x1)
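For illustration (not part of the diff), how the two branches of the revised power() behave; the inputs are assumptions:

    import jax.numpy as jnp

    x = jnp.arange(4.0)

    # A concrete integer exponent takes the lax.integer_pow branch above,
    # which is exact for the common ``x ** 2`` pattern.
    jnp.power(x, 2)    # DeviceArray([0., 1., 4., 9.], dtype=float32)

    # A floating-point (or non-concrete) exponent falls through to the
    # general lax.pow path below.
    jnp.power(x, 2.5)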
@@ -13,7 +13,8 @@
# limitations under the License.
import operator
from typing import Any, Dict, Iterable, Tuple, Union
from functools import partial
from typing import Any, Dict, Iterable, Tuple, Union, Optional
import numpy as np

@@ -85,6 +86,26 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
ans = fun(*args)
return tree_unflatten(out_tree, ans)
PyTreeDef = Any
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
# with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
# core.eval_jaxpr encounters a call primitive (though at that point we're just
# round-tripping jaxprs and the user errors in question are impossible).
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
try:
(in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
for f, args in fn.transforms if f in flat_xforms)
except ValueError:
return None
else:
return in_tree, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
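A rough sketch (not part of the diff, and exercising internal APIs only for illustration) of what flattened_fun_in_tree recovers; the toy function f and its arguments are assumptions:

    from jax import linear_util as lu
    from jax.api_util import flatten_fun_nokwargs, flattened_fun_in_tree
    from jax.tree_util import tree_flatten

    def f(xs):
      return sum(xs)

    flat_args, in_tree = tree_flatten(([1.0, 2.0, 3.0],))
    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)

    # flattened_fun_in_tree inspects flat_fun.transforms to recover the pytree
    # structure of the original arguments and whether kwargs were flattened.
    print(flattened_fun_in_tree(flat_fun))  # roughly (in_tree, False)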
@@ -451,12 +451,12 @@ def escaped_tracer_error(tracer, detail=None):
'frames (most recent last) excluding JAX-internal frames were:\n'
f'{source_info_util.summarize(line_info, num_frames=num_frames)}')
try:
fun_source_info = tracer._trace.main.source_info
dbg = tracer._trace.main.debug_info
except AttributeError:
pass
else:
msg += ('\nThe function being traced when the tracer leaked was '
f'{fun_source_info}.')
f'{dbg.func_src_info} traced for {dbg.traced_for}.')
msg += ('\nTo catch the leak earlier, try setting the environment variable '
'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
'manager.')
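A minimal sketch (not from the diff) of the leaked-tracer situation this message reports; the list and function here are made up, and the exception is caught broadly to keep the snippet self-contained:

    import jax

    leaked = []

    @jax.jit
    def f(x):
      leaked.append(x)  # the tracer escapes the trace here
      return x + 1

    f(1)
    try:
      leaked[0] + 1
    except Exception as e:  # core.UnexpectedTracerError
      print(e)  # now names the leaked function and says it was traced for jit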
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import itertools as it
from collections import namedtuple
import contextlib

@@ -26,6 +27,8 @@ from .. import core
from .._src import dtypes
from .. import linear_util as lu
from ..ad_util import Zero
from ..api_util import flattened_fun_in_tree
from .._src.tree_util import tree_unflatten, tree_leaves
from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
split_list, cache, as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,

@@ -403,8 +406,9 @@ call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, transform_name="", **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals, transform_name)
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@@ -902,24 +906,25 @@ class DynamicJaxprTracer(core.Tracer):
def _origin_msg(self):
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
dbg = self._trace.main.debug_info
if invar_pos:
origin = (f"While tracing the function {self._trace.main.source_info}, "
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this concrete value was not available in Python because it "
"depends on the value of the arguments to "
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
"and the computation of these values is being staged out "
"(that is, delayed rather than executed eagerly).")
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
elif progenitor_eqns:
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
origin = (f"While tracing the function {self._trace.main.source_info}, "
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msts))
else:
origin = ("The error occured while tracing the function "
f"{self._trace.main.source_info}.")
return origin
origin = (f"The error occured while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}.")
return "\n" + origin
def _assert_live(self) -> None:
if not self._trace.main.jaxpr_stack: # type: ignore
@@ -1176,11 +1181,71 @@ def _memoize(thunk):
return memoized

class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
PyTreeDef = Any
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
except (TypeError, ValueError):
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"

def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
debug_info: Optional[DebugInfo] = None):
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main, fun
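A small sketch (not from the diff) of the user-facing effect of the DebugInfo plumbing above; the function and inputs are made up, and the exception is caught broadly to keep the snippet self-contained:

    import jax

    @jax.jit
    def f(x, y):
      # bool() on a traced comparison raises ConcretizationTypeError; with the
      # arg_info machinery above, the message reports "the arguments 'x' and
      # 'y'" rather than flattened argument positions.
      return x if x > y else y

    try:
      f(1, 2)
    except Exception as e:  # core.ConcretizationTypeError
      print(e)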
@@ -1209,9 +1274,9 @@ def extend_jaxpr_stack(main, frame):
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

@@ -1225,16 +1290,3 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
def fun_sourceinfo(fun, transform_name: str = ""):
if isinstance(fun, functools.partial):
fun = fun.func
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
if transform_name:
line_info += f', transformed by {transform_name}.'
return line_info
except AttributeError:
return "<unknown>"
@@ -696,7 +696,8 @@ def parallel_callable(fun: lu.WrappedFun,
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_axes = out_axes_thunk()
@@ -653,8 +653,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, _ = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
abstract_args, arg_devices = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
@@ -2224,17 +2224,17 @@ class APITest(jtu.JaxTestCase):
def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by jit"):
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by pmap"):
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by eval_shape"):
"for eval_shape"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1
@@ -2278,7 +2278,19 @@ class APITest(jtu.JaxTestCase):
f() # doesn't crash
def test_concrete_error_because_arg(self):
def test_concrete_error_because_arg_unary(self):
@jax.jit
def f(x):
if x > 0:
return x
else:
return 0
msg = r"on the value of the argument 'x'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
def test_concrete_error_because_arg_binary(self):
@jax.jit
def f(x, y):
if x > y:

@@ -2286,10 +2298,67 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"at flattened positions \[0, 1\]"
msg = r"on the values of the arguments 'x' and 'y'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
def test_concrete_error_because_arg_ternary(self):
@jax.jit
def f(x, y, z):
if x > z:
return x
else:
return y
msg = r"on the values of the arguments 'x' and 'z'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, z=3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, y=2, z=3)
def test_concrete_error_because_arg_varargs(self):
@jax.jit
def f(*args):
x, y, z = args
if x > z:
return x
else:
return y
msg = r"on the values of the argument 'args'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
def test_concrete_error_because_arg_kwargs(self):
@jax.jit
def f(**kwargs):
x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
if x > z:
return x
else:
return y
msg = r"on the values of the argument 'kwargs'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
def test_concrete_error_because_arg_pytree(self):
@jax.jit
def f(xy, z):
x, y = xy
if x > 0:
return x
else:
return y
msg = r"on the value of the argument 'xy'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
def test_concrete_error_because_const(self):
@jax.jit
def f():
@@ -2617,7 +2617,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def test_unexpected_tracer_error(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by while_loop"):
"for while_loop"):
lst = []
def side_effecting_body(val):
lst.append(val)

@@ -2626,7 +2626,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lst[0] += 1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by scan"):
"for scan"):
lst = []
def side_effecting_scan(carry, val):
lst.append(val)