revise "Tracer with raw numpy" error message (#3160)

* revise "Tracer with raw numpy" error message

fixes #3133

* fix f-string typo

* fix typo

Co-authored-by: James Bradbury <jekbradbury@google.com>

Co-authored-by: James Bradbury <jekbradbury@google.com>
This commit is contained in:
Matthew Johnson 2020-05-20 19:09:44 -07:00 committed by GitHub
parent 12f26d3c8c
commit a4094f72a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 9 deletions

View File

@ -371,11 +371,18 @@ class Tracer(object):
__slots__ = ['_trace', '__weakref__']
def __array__(self, *args, **kw):
raise Exception("Tracer can't be used with raw numpy functions. "
"You might have\n"
" import numpy as np\n"
"instead of\n"
" import jax.numpy as jnp")
msg = ("The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {self}.\n\n"
"This error can occur when a JAX Tracer object is passed to a raw "
"numpy function, or a method on a numpy.ndarray object. You might "
"want to check that you are using `jnp` together with "
"`import jax.numpy as jnp` rather than using `np` via "
"`import numpy as np`. If this error arises on a line that involves "
"array indexing, like `x[idx]`, it may be that the array being "
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
"JAX Tracer instance; in that case, you can instead write "
"`jax.device_put(x)[idx]`.")
raise Exception(msg)
def __init__(self, trace):
self._trace = trace

View File

@ -177,10 +177,8 @@ class APITest(jtu.JaxTestCase):
def f(x):
return np.exp(x)
jtu.check_raises(lambda: grad(f)(np.zeros(3)), Exception,
"Tracer can't be used with raw numpy functions. "
"You might have\n import numpy as np\ninstead of\n"
" import jax.numpy as jnp")
with self.assertRaisesRegex(Exception, "The numpy.ndarray conversion .*"):
grad(f)(np.zeros(3))
def test_binop_mismatch(self):
def f(x, y):