mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add jax.errors submodule & error troubleshooting docs
This commit is contained in:
parent
59fada9446
commit
12c84e7a50
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -125,4 +125,4 @@ jobs:
|
||||
- name: Test documentation
|
||||
run: |
|
||||
pytest -n 1 docs
|
||||
pytest -n 1 --doctest-modules jax/api.py
|
||||
pytest -n 1 --doctest-modules jax/api.py jax/_src/errors.py
|
||||
|
11
docs/errors.rst
Normal file
11
docs/errors.rst
Normal file
@ -0,0 +1,11 @@
|
||||
.. _jax-errors:
|
||||
|
||||
JAX Errors
|
||||
==========
|
||||
This page lists a few of the errors you might encounter when using JAX,
|
||||
along with representative examples of how one might fix them.
|
||||
|
||||
.. currentmodule:: jax.errors
|
||||
.. autoclass:: ConcretizationTypeError
|
||||
.. autoclass:: TracerArrayConversionError
|
||||
.. autoclass:: TracerIntegerConversionError
|
45
docs/faq.rst
45
docs/faq.rst
@ -218,51 +218,12 @@ JAX/accelerators vs NumPy/CPU. For example, if switch this example to use
|
||||
|
||||
``Abstract tracer value encountered where concrete value is expected`` error
|
||||
----------------------------------------------------------------------------
|
||||
See :class:`jax.errors.ConcretizationTypeError`
|
||||
|
||||
If you are getting an error that a library function is called with
|
||||
*"Abstract tracer value encountered where concrete value is expected"*, you may need to
|
||||
change how you invoke JAX transformations. Below is an example and a couple of possible
|
||||
solutions, followed by the details of what is actually happening, if you are curious
|
||||
or the simple solution does not work for you.
|
||||
|
||||
Some library functions take arguments that specify shapes or axes,
|
||||
such as the second and third arguments for :func:`jax.numpy.split`::
|
||||
|
||||
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
|
||||
np.split(np.zeros(2), 2, 0) # works
|
||||
|
||||
If you try the following code::
|
||||
|
||||
jax.jit(np.split)(np.zeros(4), 2, 0)
|
||||
|
||||
you will get the following error::
|
||||
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
|
||||
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
|
||||
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
|
||||
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
|
||||
|
||||
You must change the way you use :func:`jax.jit` to ensure that the ``num_sections``
|
||||
and ``axis`` arguments use their concrete values (``2`` and ``0`` respectively).
|
||||
The best mechanism is to use special transformation parameters
|
||||
to declare some arguments to be static, e.g., ``static_argnums`` for :func:`jax.jit`::
|
||||
|
||||
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
|
||||
|
||||
An alternative is to apply the transformation to a closure
|
||||
that encapsulates the arguments to be protected, either manually as below
|
||||
or by using ``functools.partial``::
|
||||
|
||||
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
|
||||
|
||||
**Note a new closure is created at every invocation, which defeats the
|
||||
compilation caching mechanism, which is why static_argnums is preferred.**
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values, and
|
||||
concrete vs. abstract values, you may want to read `Different kinds of JAX values`_.
|
||||
.. _faq-different-kinds-of-jax-values:
|
||||
|
||||
Different kinds of JAX values
|
||||
------------------------------
|
||||
-----------------------------
|
||||
|
||||
In the process of transforming functions, JAX replaces some function
|
||||
arguments with special tracer values.
|
||||
|
@ -43,6 +43,7 @@ For an introduction to JAX, start at the
|
||||
|
||||
CHANGELOG
|
||||
faq
|
||||
errors
|
||||
jaxpr
|
||||
async_dispatch
|
||||
concurrency
|
||||
|
@ -137,6 +137,7 @@ Operators
|
||||
top_k
|
||||
transpose
|
||||
|
||||
.. _lax-control-flow:
|
||||
|
||||
Control flow operators
|
||||
----------------------
|
||||
|
@ -2,7 +2,7 @@ Internal APIs
|
||||
=============
|
||||
|
||||
core
|
||||
-----
|
||||
----
|
||||
|
||||
.. currentmodule:: jax.core
|
||||
.. automodule:: jax.core
|
||||
|
@ -91,6 +91,7 @@ from .version import __version__
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
# jax and rely on the names imported above.
|
||||
from . import errors
|
||||
from . import image
|
||||
from . import lax
|
||||
from . import nn
|
||||
|
292
jax/_src/errors.py
Normal file
292
jax/_src/errors.py
Normal file
@ -0,0 +1,292 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax import core
|
||||
|
||||
|
||||
class JAXTypeError(TypeError):
|
||||
"""Base class for JAX-specific TypeErrors"""
|
||||
def __init__(self, message: str):
|
||||
error_page = 'https://jax.readthedocs.io/en/latest/errors.html'
|
||||
module_name = self.__class__.__module__
|
||||
class_name = self.__class__.__name__
|
||||
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
|
||||
super().__init__(error_msg)
|
||||
|
||||
|
||||
class ConcretizationTypeError(JAXTypeError):
|
||||
"""
|
||||
This error occurs when a JAX Tracer object is used in a context where a concrete value
|
||||
is required. In some situations, it can be easily fixed by marking problematic values
|
||||
as static; in others, it may indicate that your program is doing operations that are
|
||||
not directly supported by JAX's JIT compilation model.
|
||||
|
||||
Traced value where static value is expected
|
||||
One common cause of this error is using a traced value where a static value is required.
|
||||
For example:
|
||||
|
||||
>>> from jax import jit, partial
|
||||
>>> import jax.numpy as jnp
|
||||
>>> @jit
|
||||
... def func(x, axis):
|
||||
... return x.min(axis)
|
||||
|
||||
>>> func(jnp.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
|
||||
axis argument to jnp.min().
|
||||
|
||||
This can often be fixed by marking the problematic value as static::
|
||||
|
||||
>>> @partial(jit, static_argnums=1)
|
||||
... def func(x, axis):
|
||||
... return x.min(axis)
|
||||
|
||||
>>> func(jnp.arange(4), 0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
|
||||
Traced value used in control flow
|
||||
Another case where this often arises is when a traced value is used in Python control flow.
|
||||
For example::
|
||||
|
||||
>>> @jit
|
||||
... def func(x, y):
|
||||
... return x if x.sum() < y.sum() else y
|
||||
|
||||
>>> func(jnp.ones(4), jnp.zeros(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
|
||||
The problem arose with the `bool` function.
|
||||
|
||||
In this case, marking the problematic traced quantity as static is not an option, because it
|
||||
is derived from traced inputs. But you can make progress by re-expressing this if statement
|
||||
in terms of :func:`jax.numpy.where`::
|
||||
|
||||
>>> @jit
|
||||
... def func(x, y):
|
||||
... return jnp.where(x.sum() < y.sum(), x, y)
|
||||
|
||||
>>> func(jnp.ones(4), jnp.zeros(4))
|
||||
DeviceArray([0., 0., 0., 0.], dtype=float32)
|
||||
|
||||
For more complicated control flow including loops, see :ref:`lax-control-flow`.
|
||||
|
||||
Shape depends on Traced Value
|
||||
Such an error may also arise when a shape in your JIT-compiled computation depends
|
||||
on the values within a traced quantity. For example::
|
||||
|
||||
>>> @jit
|
||||
... def func(x):
|
||||
... return jnp.where(x < 0)
|
||||
|
||||
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
|
||||
The error arose in jnp.nonzero.
|
||||
|
||||
This is an example of an operation that is incompatible with JAX's JIT compilation model,
|
||||
which requires array sizes to be known at compile-time. Here the size of the returned
|
||||
array depends on the contents of `x`, and such code cannot be JIT compiled.
|
||||
|
||||
In many cases it is possible to work around this by modifying the logic used in the function;
|
||||
for example here is code with a similar issue::
|
||||
|
||||
>>> @jit
|
||||
... def func(x):
|
||||
... indices = jnp.where(x > 1)
|
||||
... return x[indices].sum()
|
||||
|
||||
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
|
||||
The error arose in jnp.nonzero.
|
||||
|
||||
And here is how you might express the same operation in a way that avoids creation of a
|
||||
dynamically-sized index array::
|
||||
|
||||
>>> @jit
|
||||
... def func(x):
|
||||
... return jnp.where(x > 1, x, 0).sum()
|
||||
|
||||
>>> func(jnp.arange(4))
|
||||
DeviceArray(5, dtype=int32)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values, and
|
||||
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
|
||||
"""
|
||||
def __init__(self, tracer: "core.Tracer", context: str = ""):
|
||||
super().__init__(
|
||||
"Abstract tracer value encountered where concrete value is expected: "
|
||||
f"{tracer}\n{context}\n{tracer._origin_msg()}\n")
|
||||
|
||||
|
||||
class TracerArrayConversionError(JAXTypeError):
|
||||
"""
|
||||
This error occurs when a program attempts to convert a JAX Tracer object into a
|
||||
standard NumPy array. It typically occurs in one of a few situations.
|
||||
|
||||
Using `numpy` rather than `jax.numpy` functions
|
||||
This error can occur when a JAX Tracer object is passed to a raw numpy function,
|
||||
or a method on a numpy.ndarray object. For example::
|
||||
|
||||
>>> from jax import jit, partial
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
|
||||
>>> @jit
|
||||
... def func(x):
|
||||
... return np.sin(x)
|
||||
|
||||
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
|
||||
|
||||
In this case, check that you are using `jax.numpy` methods rather than `numpy` methods::
|
||||
|
||||
>>> @jit
|
||||
... def func(x):
|
||||
... return jnp.sin(x)
|
||||
|
||||
>>> func(jnp.arange(4))
|
||||
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
|
||||
|
||||
Indexing a numpy array with a tracer
|
||||
If this error arises on a line that involves array indexing, it may be that the array being
|
||||
indexed `x` is a raw numpy.ndarray while the indices `idx` are traced. For example::
|
||||
|
||||
>>> x = np.arange(10)
|
||||
|
||||
>>> @jit
|
||||
... def func(i):
|
||||
... return x[i]
|
||||
|
||||
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
|
||||
|
||||
Depending on the context, you may fix this by converting the numpy array into a JAX array::
|
||||
|
||||
>>> @jit
|
||||
... def func(i):
|
||||
... return jnp.asarray(x)[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
|
||||
or by declaring the index as a static argument::
|
||||
|
||||
>>> @partial(jit, static_argnums=(0,))
|
||||
... def func(i):
|
||||
... return x[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
|
||||
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
|
||||
"""
|
||||
def __init__(self, tracer: "core.Tracer"):
|
||||
super().__init__(
|
||||
"The numpy.ndarray conversion method __array__() was called on "
|
||||
f"the JAX Tracer object {tracer}")
|
||||
|
||||
|
||||
class TracerIntegerConversionError(JAXTypeError):
|
||||
"""
|
||||
This error can occur when a JAX Tracer object is used in a context where a Python integer
|
||||
is expected. It typically occurs in a few situations.
|
||||
|
||||
Passing a tracer in place of an integer
|
||||
This error can occur if you attempt to pass a tracer to a function that requires an integer
|
||||
argument; for example::
|
||||
|
||||
>>> from jax import jit, partial
|
||||
>>> import numpy as np
|
||||
|
||||
>>> @jit
|
||||
... def func(x, axis):
|
||||
... return np.split(x, 2, axis)
|
||||
|
||||
>>> func(np.arange(4), 0) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
|
||||
|
||||
When this happens, the solution is often to mark the problematic argument as static::
|
||||
|
||||
>>> @partial(jit, static_argnums=1)
|
||||
... def func(x, axis):
|
||||
... return np.split(x, 2, axis)
|
||||
|
||||
>>> func(np.arange(10), 0)
|
||||
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
|
||||
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
|
||||
|
||||
An alternative is to apply the transformation to a closure that encapsulates the arguments
|
||||
to be protected, either manually as below or by using :func:`functools.partial`::
|
||||
|
||||
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
|
||||
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
|
||||
|
||||
**Note a new closure is created at every invocation, which defeats the compilation
|
||||
caching mechanism, which is why static_argnums is preferred.**
|
||||
|
||||
Indexing a list with a Tracer
|
||||
This error can occur if you attempt to index a Python list with a traced quantity.
|
||||
For example::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax import jit, partial
|
||||
|
||||
>>> L = [1, 2, 3]
|
||||
|
||||
>>> @jit
|
||||
... def func(i):
|
||||
... return L[i]
|
||||
|
||||
>>> func(0) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object
|
||||
|
||||
Depending on the context, you can generally fix this either by converting the list
|
||||
to a JAX array::
|
||||
|
||||
>>> @jit
|
||||
... def func(i):
|
||||
... return jnp.array(L)[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(1, dtype=int32)
|
||||
|
||||
or by declaring the index as a static argument::
|
||||
|
||||
>>> @partial(jit, static_argnums=0)
|
||||
... def func(i):
|
||||
... return L[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(1, dtype=int32)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
|
||||
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
|
||||
"""
|
||||
def __init__(self, tracer: "core.Tracer"):
|
||||
super().__init__(
|
||||
f"The __index__() method was called on the JAX Tracer object {tracer}")
|
40
jax/core.py
40
jax/core.py
@ -30,6 +30,8 @@ import numpy as np
|
||||
|
||||
from . import dtypes
|
||||
from .config import FLAGS, config
|
||||
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError)
|
||||
from . import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
@ -483,28 +485,10 @@ class Tracer:
|
||||
__slots__ = ['_trace', '__weakref__', '_line_info']
|
||||
|
||||
def __array__(self, *args, **kw):
|
||||
msg = ("The numpy.ndarray conversion method __array__() was called on "
|
||||
f"the JAX Tracer object {self}.\n\n"
|
||||
"This error can occur when a JAX Tracer object is passed to a raw "
|
||||
"numpy function, or a method on a numpy.ndarray object. You might "
|
||||
"want to check that you are using `jnp` together with "
|
||||
"`import jax.numpy as jnp` rather than using `np` via "
|
||||
"`import numpy as np`. If this error arises on a line that involves "
|
||||
"array indexing, like `x[idx]`, it may be that the array being "
|
||||
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
|
||||
"JAX Tracer instance; in that case, you can instead write "
|
||||
"`jnp.asarray(x)[idx]`.")
|
||||
raise Exception(msg)
|
||||
raise TracerArrayConversionError(self)
|
||||
|
||||
def __index__(self):
|
||||
msg = (f"The __index__ method was called on the JAX Tracer object {self}.\n\n"
|
||||
"This error can occur when a JAX Tracer object is used in a context where "
|
||||
"a Python integer is expected, such as an argument to the range() function, "
|
||||
"or in index to a Python list. In the latter case, this can often be fixed "
|
||||
"by converting the indexed object to a JAX array, for example by changing "
|
||||
"`obj[idx]` to `jnp.asarray(obj)[idx]`."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
raise TracerIntegerConversionError(self)
|
||||
|
||||
def __init__(self, trace: Trace):
|
||||
self._trace = trace
|
||||
@ -956,17 +940,6 @@ unitvar = UnitVar()
|
||||
|
||||
pytype_aval_mappings[Unit] = lambda _: abstract_unit
|
||||
|
||||
class ConcretizationTypeError(TypeError): pass
|
||||
|
||||
def raise_concretization_error(val: Tracer, context=""):
|
||||
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
|
||||
+ context + "\n\n"
|
||||
+ val._origin_msg() + "\n\n"
|
||||
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
|
||||
f"Encountered tracer value: {val}")
|
||||
raise ConcretizationTypeError(msg)
|
||||
|
||||
|
||||
def concretization_function_error(fun, suggest_astype=False):
|
||||
fname = getattr(fun, "__name__", fun)
|
||||
fname_context = f"The problem arose with the `{fname}` function. "
|
||||
@ -975,10 +948,9 @@ def concretization_function_error(fun, suggest_astype=False):
|
||||
f"try using `x.astype({fun.__name__})` "
|
||||
f"or `jnp.array(x, {fun.__name__})` instead.")
|
||||
def error(self, arg):
|
||||
raise_concretization_error(arg, fname_context)
|
||||
raise ConcretizationTypeError(arg, fname_context)
|
||||
return error
|
||||
|
||||
|
||||
def concrete_or_error(force: Any, val: Any, context=""):
|
||||
"""Like force(val), but gives the context in the error message."""
|
||||
if force is None:
|
||||
@ -987,7 +959,7 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
if isinstance(val.aval, ConcreteArray):
|
||||
return force(val.aval.val)
|
||||
else:
|
||||
raise_concretization_error(val, context)
|
||||
raise ConcretizationTypeError(val, context)
|
||||
else:
|
||||
return force(val)
|
||||
|
||||
|
19
jax/errors.py
Normal file
19
jax/errors.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from ._src.errors import (JAXTypeError,
|
||||
ConcretizationTypeError,
|
||||
TracerArrayConversionError,
|
||||
TracerIntegerConversionError)
|
@ -922,10 +922,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
"depends on the value of the arguments to "
|
||||
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
|
||||
"and the computation of these values is being staged out "
|
||||
"(that is, delayed rather than executed eagerly).\n\n"
|
||||
"You can use transformation parameters such as `static_argnums` "
|
||||
"for `jit` to avoid tracing particular arguments of transformed "
|
||||
"functions, though at the cost of more recompiles.")
|
||||
"(that is, delayed rather than executed eagerly).")
|
||||
elif progenitor_eqns:
|
||||
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
|
||||
f" from line {source_info_util.summarize(eqn.source_info)}"
|
||||
|
@ -658,7 +658,7 @@ class APITest(jtu.JaxTestCase):
|
||||
assert jit(f, static_argnums=(0,))(0) == L[0]
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The __index__ method was called on the JAX Tracer object.*",
|
||||
r"The __index__\(\) method was called on the JAX Tracer object.*",
|
||||
lambda: jit(f)(0))
|
||||
|
||||
def test_range_err(self):
|
||||
@ -670,7 +670,7 @@ class APITest(jtu.JaxTestCase):
|
||||
assert jit(f, static_argnums=(1,))(0, 5) == 10
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The __index__ method was called on the JAX Tracer object.*",
|
||||
r"The __index__\(\) method was called on the JAX Tracer object.*",
|
||||
lambda: jit(f)(0, 5))
|
||||
|
||||
def test_cast_int(self):
|
||||
@ -685,7 +685,7 @@ class APITest(jtu.JaxTestCase):
|
||||
f = lambda x: castfun(x)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The __index__ method was called on the JAX Tracer object.*", lambda: jit(f)(0))
|
||||
r"The __index__\(\) method was called on the JAX Tracer object.*", lambda: jit(f)(0))
|
||||
|
||||
def test_unimplemented_interpreter_rules(self):
|
||||
foo_p = Primitive('foo')
|
||||
|
@ -3228,7 +3228,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
x = x * x
|
||||
return x
|
||||
|
||||
self.assertRaises(TypeError, lambda: f(3., 3))
|
||||
self.assertRaises(jax.errors.TracerIntegerConversionError, lambda: f(3., 3))
|
||||
|
||||
@api.jit
|
||||
def g(x):
|
||||
@ -3237,7 +3237,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
self.assertRaises(TypeError, lambda: g(3.))
|
||||
self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.))
|
||||
|
||||
def testTracingPrimitiveWithNoTranslationErrorMessage(self):
|
||||
# TODO(mattjj): update this for jax3
|
||||
|
Loading…
x
Reference in New Issue
Block a user