mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
684 lines
24 KiB
Python
684 lines
24 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
from jax._src import core
|
|
from jax._src.util import set_module
|
|
|
|
# Decorator applied to every public error class in this module. ``set_module``
# is imported from jax._src.util; presumably it rewrites the decorated class's
# ``__module__`` so the errors present themselves as ``jax.errors.<Name>``
# (matching ``_JAXErrorMixin._module_name``) -- TODO confirm against
# set_module's definition.
export = set_module('jax.errors')
|
|
|
|
|
|
class _JAXErrorMixin:
  """Mixin for JAX-specific errors.

  Appends a pointer to the JAX error reference page to every message. The
  page anchor is ``<_module_name>.<concrete class name>``, so a subclass
  ``Foo`` links to ``<_error_page>#jax.errors.Foo``. Subclasses pair this
  mixin with a builtin exception type (e.g. ``TypeError``), whose initializer
  receives the decorated message.
  """
  _error_page = 'https://jax.readthedocs.io/en/latest/errors.html'
  _module_name = "jax.errors"

  def __init__(self, message: str):
    anchor = f'{self._module_name}.{type(self).__name__}'
    decorated = f'{message}\nSee {self._error_page}#{anchor}'
    # https://github.com/python/mypy/issues/5887
    super().__init__(decorated)  # type: ignore
|
|
|
|
|
|
@export
class JAXTypeError(_JAXErrorMixin, TypeError):
  """A ``TypeError`` raised by JAX whose message links to the JAX error docs."""
|
|
|
|
|
|
@export
class JAXIndexError(_JAXErrorMixin, IndexError):
  """An ``IndexError`` raised by JAX whose message links to the JAX error docs."""
|
|
|
|
|
|
@export
class ConcretizationTypeError(JAXTypeError):
  """
  This error occurs when a JAX Tracer object is used in a context where a
  concrete value is required (see :ref:`faq-different-kinds-of-jax-values`
  for more on what a Tracer is). In some situations, it can be easily fixed by
  marking problematic values as static; in others, it may indicate that your
  program is doing operations that are not directly supported by JAX's JIT
  compilation model.

  Examples:

  Traced value where static value is expected
    One common cause of this error is using a traced value where a static value
    is required. For example:

      >>> from functools import partial
      >>> from jax import jit
      >>> import jax.numpy as jnp
      >>> @jit
      ... def func(x, axis):
      ...   return x.min(axis)

      >>> func(jnp.arange(4), 0)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      ConcretizationTypeError: Abstract tracer value encountered where concrete
      value is expected: axis argument to jnp.min().

    This can often be fixed by marking the problematic argument as static::

      >>> @partial(jit, static_argnums=1)
      ... def func(x, axis):
      ...   return x.min(axis)

      >>> func(jnp.arange(4), 0)
      Array(0, dtype=int32)

  Shape depends on Traced Value
    Such an error may also arise when a shape in your JIT-compiled computation
    depends on the values within a traced quantity. For example::

      >>> @jit
      ... def func(x):
      ...   return jnp.where(x < 0)

      >>> func(jnp.arange(4))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
      The error arose in jnp.nonzero.

    This is an example of an operation that is incompatible with JAX's JIT
    compilation model, which requires array sizes to be known at compile-time.
    Here the size of the returned array depends on the contents of ``x``, and
    such code cannot be JIT compiled.

    In many cases it is possible to work around this by modifying the logic used
    in the function; for example here is code with a similar issue::

      >>> @jit
      ... def func(x):
      ...   indices = jnp.where(x > 1)
      ...   return x[indices].sum()

      >>> func(jnp.arange(4))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      ConcretizationTypeError: Abstract tracer value encountered where concrete
      value is expected: The error arose in jnp.nonzero.

    And here is how you might express the same operation in a way that avoids
    creation of a dynamically-sized index array::

      >>> @jit
      ... def func(x):
      ...   return jnp.where(x > 1, x, 0).sum()

      >>> func(jnp.arange(4))
      Array(5, dtype=int32)

  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """

  def __init__(self, tracer: core.Tracer, context: str = ""):
    # Query the tracer in the same order as the message reads: first what the
    # value is, then (after the caller-supplied context) where it came from.
    description = tracer._error_repr()
    origin = tracer._origin_msg()
    super().__init__(
        "Abstract tracer value encountered where concrete value is expected: "
        f"{description}\n{context}{origin}\n")
|
|
|
|
|
|
@export
class NonConcreteBooleanIndexError(JAXIndexError):
  """
  This error occurs when a program attempts to use non-concrete boolean indices
  in a traced indexing operation. Under JIT compilation, JAX arrays must have
  static shapes (i.e. shapes that are known at compile-time) and so boolean
  masks must be used carefully. Some logic implemented via boolean masking is
  simply not possible in a :func:`jax.jit` function; in other cases, the logic
  can be re-expressed in a JIT-compatible way, often using the three-argument
  version of :func:`~jax.numpy.where`.

  Following are a few examples of when this error might arise.

  Constructing arrays via boolean masking
    This most commonly arises when attempting to create an array via a boolean
    mask within a JIT context. For example::

      >>> import jax
      >>> import jax.numpy as jnp

      >>> @jax.jit
      ... def positive_values(x):
      ...   return x[x > 0]

      >>> positive_values(jnp.arange(-5, 5))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

    This function is attempting to return only the positive values in the input
    array; the size of this returned array cannot be determined at compile-time
    unless ``x`` is marked as static, and so operations like this cannot be
    performed under JIT compilation.

  Reexpressible Boolean Logic
    Although creating dynamically sized arrays is not supported directly, in
    many cases it is possible to re-express the logic of the computation in
    terms of a JIT-compatible operation. For example, here is another function
    that fails under JIT for the same reason::

      >>> @jax.jit
      ... def sum_of_positive(x):
      ...   return x[x > 0].sum()

      >>> sum_of_positive(jnp.arange(-5, 5))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

    In this case, however, the problematic array is only an intermediate value,
    and we can instead express the same logic in terms of the JIT-compatible
    three-argument version of :func:`jax.numpy.where`::

      >>> @jax.jit
      ... def sum_of_positive(x):
      ...   return jnp.where(x > 0, x, 0).sum()

      >>> sum_of_positive(jnp.arange(-5, 5))
      Array(10, dtype=int32)

    This pattern of replacing boolean masking with three-argument
    :func:`~jax.numpy.where` is a common solution to this sort of problem.

  Boolean indexing into JAX arrays
    The other situation where this error often arises is when using boolean
    indices, such as with :code:`.at[...].set(...)`. Here is a simple example::

      >>> @jax.jit
      ... def manual_clip(x):
      ...   return x.at[x < 0].set(0)

      >>> manual_clip(jnp.arange(-2, 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

    This function is attempting to set values smaller than zero to a scalar fill
    value. As above, this can be addressed by re-expressing the logic in terms
    of :func:`~jax.numpy.where`::

      >>> @jax.jit
      ... def manual_clip(x):
      ...   return jnp.where(x < 0, 0, x)

      >>> manual_clip(jnp.arange(-2, 2))
      Array([0, 0, 0, 1], dtype=int32)
  """

  def __init__(self, tracer: core.Tracer):
    # The tracer is interpolated via its default string conversion.
    message = f"Array boolean indices must be concrete; got {tracer}\n"
    super().__init__(message)
|
|
|
|
|
|
@export
class TracerArrayConversionError(JAXTypeError):
  """
  This error occurs when a program attempts to convert a JAX Tracer object into
  a standard NumPy array (see :ref:`faq-different-kinds-of-jax-values` for more
  on what a Tracer is). It typically occurs in one of a few situations.

  Using non-JAX functions in JAX transformations
    This error can occur if you attempt to use a non-JAX library like ``numpy``
    or ``scipy`` inside a JAX transformation (:func:`~jax.jit`, :func:`~jax.grad`,
    :func:`jax.vmap`, etc.). For example::

      >>> from jax import jit
      >>> import numpy as np

      >>> @jit
      ... def func(x):
      ...   return np.sin(x)

      >>> func(np.arange(4))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerArrayConversionError: The numpy.ndarray conversion method
      __array__() was called on traced array with shape int32[4]

    In this case, you can fix the issue by using :func:`jax.numpy.sin` in place of
    :func:`numpy.sin`::

      >>> import jax.numpy as jnp
      >>> @jit
      ... def func(x):
      ...   return jnp.sin(x)

      >>> func(jnp.arange(4))
      Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

    See also `External Callbacks`_ for options for calling back to host-side
    computations from transformed JAX code.

  Indexing a numpy array with a tracer
    If this error arises on a line that involves array indexing, it may be that
    the array being indexed ``x`` is a standard numpy.ndarray while the indices
    ``idx`` are traced JAX arrays. For example::

      >>> x = np.arange(10)

      >>> @jit
      ... def func(i):
      ...   return x[i]

      >>> func(0)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerArrayConversionError: The numpy.ndarray conversion method
      __array__() was called on traced array with shape int32[]

    Depending on the context, you may fix this by converting the numpy array
    into a JAX array::

      >>> @jit
      ... def func(i):
      ...   return jnp.asarray(x)[i]

      >>> func(0)
      Array(0, dtype=int32)

    or by declaring the index as a static argument::

      >>> from functools import partial
      >>> @partial(jit, static_argnums=(0,))
      ... def func(i):
      ...   return x[i]

      >>> func(0)
      Array(0, dtype=int32)

  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.

  .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
  """

  def __init__(self, tracer: core.Tracer):
    # Describe the offending value first, then append where it originated.
    described = tracer._error_repr()
    origin = tracer._origin_msg()
    super().__init__(
        "The numpy.ndarray conversion method __array__() was called on "
        f"{described}{origin}")
|
|
|
|
|
|
@export
class TracerIntegerConversionError(JAXTypeError):
  """
  This error can occur when a JAX Tracer object is used in a context where a
  Python integer is expected (see :ref:`faq-different-kinds-of-jax-values` for
  more on what a Tracer is). It typically occurs in a few situations.

  Passing a tracer in place of an integer
    This error can occur if you attempt to pass a traced value to a function
    that requires a static integer argument; for example::

      >>> from jax import jit
      >>> import numpy as np

      >>> @jit
      ... def func(x, axis):
      ...   return np.split(x, 2, axis)

      >>> func(np.arange(4), 0)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerIntegerConversionError: The __index__() method was called on
      traced array with shape int32[]

    When this happens, the solution is often to mark the problematic argument as
    static::

      >>> from functools import partial
      >>> @partial(jit, static_argnums=1)
      ... def func(x, axis):
      ...   return np.split(x, 2, axis)

      >>> func(np.arange(10), 0)
      [Array([0, 1, 2, 3, 4], dtype=int32),
       Array([5, 6, 7, 8, 9], dtype=int32)]

    An alternative is to apply the transformation to a closure that encapsulates
    the arguments to be protected, either manually as below or by using
    :func:`functools.partial`::

      >>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
      [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

    **Note a new closure is created at every invocation, which defeats the
    compilation caching mechanism, which is why static_argnums is preferred.**

  Indexing a list with a Tracer
    This error can occur if you attempt to index a Python list with a traced
    quantity. For example::

      >>> import jax.numpy as jnp
      >>> from jax import jit

      >>> L = [1, 2, 3]

      >>> @jit
      ... def func(i):
      ...   return L[i]

      >>> func(0)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerIntegerConversionError: The __index__() method was called on
      traced array with shape int32[]

    Depending on the context, you can generally fix this either by converting
    the list to a JAX array::

      >>> @jit
      ... def func(i):
      ...   return jnp.array(L)[i]

      >>> func(0)
      Array(1, dtype=int32)

    or by declaring the index as a static argument::

      >>> from functools import partial
      >>> @partial(jit, static_argnums=0)
      ... def func(i):
      ...   return L[i]

      >>> func(0)
      Array(1, dtype=int32, weak_type=True)

  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """

  def __init__(self, tracer: core.Tracer):
    # What the value is, followed directly by where it originated.
    what = tracer._error_repr()
    where = tracer._origin_msg()
    super().__init__(f"The __index__() method was called on {what}{where}")
|
|
|
|
|
|
@export
class TracerBoolConversionError(ConcretizationTypeError):
  """
  This error occurs when a traced value in JAX is used in a context where a
  boolean value is expected (see :ref:`faq-different-kinds-of-jax-values`
  for more on what a Tracer is).

  The boolean cast may be explicit (e.g. ``bool(x)``) or implicit, through use
  of control flow (e.g. ``if x > 0`` or ``while x``), use of Python boolean
  operators (e.g. ``z = x and y``, ``z = x or y``, ``z = not x``) or functions
  that use them (e.g. ``z = max(x, y)``, ``z = min(x, y)`` etc.).

  In some situations, this problem can be easily fixed by marking traced values
  as static; in others, it may indicate that your program is doing operations
  that are not directly supported by JAX's JIT compilation model.

  Examples:

  Traced value used in control flow
    One case where this often arises is when a traced value is used in
    Python control flow. For example::

      >>> from jax import jit
      >>> import jax.numpy as jnp
      >>> @jit
      ... def func(x, y):
      ...   return x if x.sum() < y.sum() else y

      >>> func(jnp.ones(4), jnp.zeros(4))  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

    We could mark both inputs ``x`` and ``y`` as static, but that would defeat
    the purpose of using :func:`jax.jit` here. Another option is to re-express
    the if statement in terms of the three-term :func:`jax.numpy.where`::

      >>> @jit
      ... def func(x, y):
      ...   return jnp.where(x.sum() < y.sum(), x, y)

      >>> func(jnp.ones(4), jnp.zeros(4))
      Array([0., 0., 0., 0.], dtype=float32)

    For more complicated control flow including loops, see
    :ref:`lax-control-flow`.

  Control flow on traced values
    Another common cause of this error is if you inadvertently trace over a
    boolean flag. For example::

      >>> @jit
      ... def func(x, normalize=True):
      ...   if normalize:
      ...     return x / x.sum()
      ...   return x

      >>> func(jnp.arange(5), True)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

    Here because the flag ``normalize`` is traced, it cannot be used in Python
    control flow. In this situation, the best solution is probably to mark this
    value as static::

      >>> from functools import partial
      >>> @partial(jit, static_argnames=['normalize'])
      ... def func(x, normalize=True):
      ...   if normalize:
      ...     return x / x.sum()
      ...   return x

      >>> func(jnp.arange(5), True)
      Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

    For more on ``static_argnums`` and ``static_argnames``, see the
    documentation of :func:`jax.jit`.

  Using non-JAX aware functions
    Another common cause of this error is using non-JAX aware functions within
    JAX code. For example::

      >>> @jit
      ... def func(x):
      ...   return min(x, 0)

      >>> func(2)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

    In this case, the error occurs because Python's built-in ``min`` function is
    not compatible with JAX transforms. This can be fixed by replacing it with
    ``jnp.minimum``::

      >>> @jit
      ... def func(x):
      ...   return jnp.minimum(x, 0)

      >>> print(func(2))
      0

  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """

  def __init__(self, tracer: core.Tracer):
    # Deliberately skip ConcretizationTypeError.__init__, which would prefix
    # "Abstract tracer value encountered where concrete value is expected" and
    # takes a ``context`` argument; this error builds its own message and hands
    # it straight to JAXTypeError's initializer.
    JAXTypeError.__init__(self,
                          f"Attempted boolean conversion of {tracer._error_repr()}."
                          f"{tracer._origin_msg()}")
|
|
|
|
|
|
@export
class UnexpectedTracerError(JAXTypeError):
  """
  This error occurs when you use a JAX value that has leaked out of a function.
  What does it mean to leak a value? If you use a JAX transformation on a
  function ``f`` that stores, in some scope outside of ``f``, a reference to
  an intermediate value, that value is considered to have been leaked.
  Leaking values is a side effect. (Read more about avoiding side effects in
  `Pure Functions <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_)

  JAX detects leaks when you then use the leaked value in another
  operation later on, at which point it raises an ``UnexpectedTracerError``.
  To fix this, avoid side effects: if a function computes a value needed
  in an outer scope, return that value from the transformed function explicitly.

  Specifically, a ``Tracer`` is JAX's internal representation of a function's
  intermediate values during transformations, e.g. within :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.vmap`, etc. Encountering a ``Tracer`` outside
  of a transformation implies a leak.

  Life-cycle of a leaked value
    Consider the following example of a transformed function which leaks a value
    to an outer scope::

      >>> from jax import jit
      >>> import jax.numpy as jnp

      >>> outs = []
      >>> @jit                   # 1
      ... def side_effecting(x):
      ...   y = x + 1            # 3
      ...   outs.append(y)       # 4

      >>> x = 1
      >>> side_effecting(x)      # 2
      >>> outs[0] + 1            # 5  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      UnexpectedTracerError: Encountered an unexpected tracer.

    In this example we leak a Traced value from an inner transformed scope to an
    outer scope. We get an ``UnexpectedTracerError`` when the leaked value is
    used, not when the value is leaked.

    This example also demonstrates the life-cycle of a leaked value:

      1. A function is transformed (in this case, by :func:`~jax.jit`)
      2. The transformed function is called (initiating an abstract trace of the
         function and turning ``x`` into a ``Tracer``)
      3. The intermediate value ``y``, which will later be leaked, is created
         (an intermediate value of a traced function is also a ``Tracer``)
      4. The value is leaked (appended to a list in an outer scope, escaping
         the function through a side-channel)
      5. The leaked value is used, and an UnexpectedTracerError is raised.

    The UnexpectedTracerError message tries to point to these locations in your
    code by including information about each stage. Respectively:

      1. The name of the transformed function (``side_effecting``) and which
         transform kicked off the trace (:func:`~jax.jit`).
      2. A reconstructed stack trace of where the leaked Tracer was created,
         which includes where the transformed function was called.
         (``When the Tracer was created, the final 5 stack frames were...``).
      3. From the reconstructed stack trace, the line of code that created
         the leaked Tracer.
      4. The leak location is not included in the error message because it is
         difficult to pin down! JAX can only tell you what the leaked value
         looks like (what shape it has and where it was created) and what
         boundary it was leaked over (the name of the transformation and the
         name of the transformed function).
      5. The current error's stack trace points to where the value is used.

    The error can be fixed by returning the value out of the
    transformed function::

      >>> from jax import jit
      >>> import jax.numpy as jnp

      >>> outs = []
      >>> @jit
      ... def not_side_effecting(x):
      ...   y = x+1
      ...   return y

      >>> x = 1
      >>> y = not_side_effecting(x)
      >>> outs.append(y)
      >>> outs[0] + 1  # all good! no longer a leaked value.
      Array(3, dtype=int32, weak_type=True)

  Leak checker
    As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace
    which points to where the leaked value was created. This is because
    JAX only raises an error when the leaked value is used, not when the
    value is leaked. This is not the most useful place to raise this error,
    because you need to know the location where the Tracer was leaked to fix the
    error.

    To make this location easier to track down, you can use the leak checker.
    When the leak checker is enabled, an error is raised as soon as a ``Tracer``
    is leaked. (To be more exact, it will raise an error when the transformed
    function from which the ``Tracer`` is leaked returns.)

    To enable the leak checker you can use the ``JAX_CHECK_TRACER_LEAKS``
    environment variable or the ``with jax.checking_leaks()`` context manager.

    .. note::
      Note that this tool is experimental and may report false positives. It
      works by disabling some JAX caches, so it will have a negative effect on
      performance and should only be used when debugging.

    Example usage::

      >>> import jax
      >>> from jax import jit
      >>> import jax.numpy as jnp

      >>> outs = []
      >>> @jit
      ... def side_effecting(x):
      ...   y = x+1
      ...   outs.append(y)

      >>> x = 1
      >>> with jax.checking_leaks():
      ...   y = side_effecting(x)  # doctest: +IGNORE_EXCEPTION_DETAIL
      Traceback (most recent call last):
          ...
      Exception: Leaked Trace

  """

  def __init__(self, msg: str):
    # The caller supplies the complete message; the JAX error-page link is
    # appended by _JAXErrorMixin. The parameter keeps the name ``msg`` so
    # keyword callers continue to work.
    super().__init__(msg)
|
|
|
|
|
|
@export
class KeyReuseError(JAXTypeError):
  """
  This error occurs when a PRNG key is reused in an unsafe manner.
  Key reuse is checked only when ``jax_enable_key_reuse_checks`` is
  set to ``True``.

  Here is a simple example of code that would lead to such an error::

    >>> with jax.enable_key_reuse_checks(True):  # doctest: +SKIP
    ...   key = jax.random.key(0)
    ...   value = jax.random.uniform(key)
    ...   new_value = jax.random.uniform(key)
    ...
    ---------------------------------------------------------------------------
    KeyReuseError                             Traceback (most recent call last)
    ...
    KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

  This sort of key reuse is problematic because the JAX PRNG is stateless, and
  keys must be manually split. For more information on this see `Sharp Bits:
  Random Numbers
  <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers>`_.
  """
|