Improve error when tracer is used as a list index

This commit is contained in:
Jake VanderPlas 2021-02-25 13:35:41 -08:00
parent b21341d0ad
commit 56687e92e8
2 changed files with 32 additions and 6 deletions

View File

@ -493,9 +493,19 @@ class Tracer:
"array indexing, like `x[idx]`, it may be that the array being "
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
"JAX Tracer instance; in that case, you can instead write "
"`jax.device_put(x)[idx]`.")
"`jnp.asarray(x)[idx]`.")
raise Exception(msg)
def __index__(self):
msg = (f"The __index__ method was called on the JAX Tracer object {self}.\n\n"
"This error can occur when a JAX Tracer object is used in a context where "
"a Python integer is expected, such as an argument to the range() function, "
"or in index to a Python list. In the latter case, this can often be fixed "
"by converting the indexed object to a JAX array, for example by changing "
"`obj[idx]` to `jnp.asarray(obj)[idx]`."
)
raise TypeError(msg)
def __init__(self, trace: Trace):
self._trace = trace

View File

@ -650,6 +650,17 @@ class APITest(jtu.JaxTestCase):
"Abstract tracer value"):
jit(f)(1)
def test_list_index_err(self):
L = [1, 2, 3]
def f(n):
return L[n]
assert jit(f, static_argnums=(0,))(0) == L[0]
self.assertRaisesRegex(
TypeError,
"The __index__ method was called on the JAX Tracer object.*",
lambda: jit(f)(0))
def test_range_err(self):
def f(x, n):
for i in range(n):
@ -659,17 +670,22 @@ class APITest(jtu.JaxTestCase):
assert jit(f, static_argnums=(1,))(0, 5) == 10
self.assertRaisesRegex(
TypeError,
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
"|Abstract value passed to .*)",
"The __index__ method was called on the JAX Tracer object.*",
lambda: jit(f)(0, 5))
def test_cast_int(self):
f = lambda x: int(x)
self.assertRaisesRegex(
TypeError,
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
def test_casts(self):
for castfun in [hex, oct, int]:
for castfun in [hex, oct]:
f = lambda x: castfun(x)
self.assertRaisesRegex(
TypeError,
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
"The __index__ method was called on the JAX Tracer object.*", lambda: jit(f)(0))
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')