mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error when tracer is used as a list index
This commit is contained in:
parent
b21341d0ad
commit
56687e92e8
12
jax/core.py
12
jax/core.py
@ -493,9 +493,19 @@ class Tracer:
|
||||
"array indexing, like `x[idx]`, it may be that the array being "
|
||||
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
|
||||
"JAX Tracer instance; in that case, you can instead write "
|
||||
"`jax.device_put(x)[idx]`.")
|
||||
"`jnp.asarray(x)[idx]`.")
|
||||
raise Exception(msg)
|
||||
|
||||
def __index__(self):
|
||||
msg = (f"The __index__ method was called on the JAX Tracer object {self}.\n\n"
|
||||
"This error can occur when a JAX Tracer object is used in a context where "
|
||||
"a Python integer is expected, such as an argument to the range() function, "
|
||||
"or in index to a Python list. In the latter case, this can often be fixed "
|
||||
"by converting the indexed object to a JAX array, for example by changing "
|
||||
"`obj[idx]` to `jnp.asarray(obj)[idx]`."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
def __init__(self, trace: Trace):
|
||||
self._trace = trace
|
||||
|
||||
|
@ -650,6 +650,17 @@ class APITest(jtu.JaxTestCase):
|
||||
"Abstract tracer value"):
|
||||
jit(f)(1)
|
||||
|
||||
def test_list_index_err(self):
|
||||
L = [1, 2, 3]
|
||||
def f(n):
|
||||
return L[n]
|
||||
|
||||
assert jit(f, static_argnums=(0,))(0) == L[0]
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The __index__ method was called on the JAX Tracer object.*",
|
||||
lambda: jit(f)(0))
|
||||
|
||||
def test_range_err(self):
|
||||
def f(x, n):
|
||||
for i in range(n):
|
||||
@ -659,17 +670,22 @@ class APITest(jtu.JaxTestCase):
|
||||
assert jit(f, static_argnums=(1,))(0, 5) == 10
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
|
||||
"|Abstract value passed to .*)",
|
||||
"The __index__ method was called on the JAX Tracer object.*",
|
||||
lambda: jit(f)(0, 5))
|
||||
|
||||
def test_cast_int(self):
|
||||
f = lambda x: int(x)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
|
||||
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
|
||||
|
||||
def test_casts(self):
|
||||
for castfun in [hex, oct, int]:
|
||||
for castfun in [hex, oct]:
|
||||
f = lambda x: castfun(x)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
|
||||
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
|
||||
"The __index__ method was called on the JAX Tracer object.*", lambda: jit(f)(0))
|
||||
|
||||
def test_unimplemented_interpreter_rules(self):
|
||||
foo_p = Primitive('foo')
|
||||
|
Loading…
x
Reference in New Issue
Block a user