Improve errors for failed compilations w/ core.concrete_or_error

This commit is contained in:
Jake VanderPlas 2020-10-02 15:08:21 -07:00
parent 95e3fcdf90
commit 492889f4a4

View File

@ -1459,6 +1459,8 @@ def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
ar1 = core.concrete_or_error(asarray, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(asarray, ar2, "The error arose in intersect1d()")
if not assume_unique:
if return_indices:
@ -1558,6 +1560,9 @@ def bincount(x, weights=None, minlength=0, *, length=None):
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
raise TypeError(msg)
if length is None:
x = core.concrete_or_error(array, x,
"The error occured because of argument 'x' of jnp.bincount. "
"To avoid this error, pass a static `length` argument.")
length = max(x) + 1
length = _max(length, minlength)
if ndim(x) != 1:
@ -1997,7 +2002,7 @@ because its output shape is data-dependent.
def nonzero(a):
# Note: this function cannot be jitted because its output has a dynamic
# shape.
a = atleast_1d(a)
a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
dims = shape(a)
ndims = len(dims)
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
@ -2778,7 +2783,10 @@ def repeat(a, repeats, axis=None, *, total_repeat_length=None):
# If total_repeat_length is not given, can't compile, use a default.
if total_repeat_length is None:
repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.")
repeats = core.concrete_or_error(np.array, repeats,
"When jit-compiling jnp.repeat, the total number of repeats must be static. "
"To fix this, either specify a static value for `repeats`, or pass a static "
"value to `total_repeat_length`.")
repeats = np.ravel(repeats)
if ndim(a) != 0:
repeats = np.broadcast_to(repeats, [a.shape[axis]])
@ -3021,6 +3029,8 @@ def polyder(p, m=1):
@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
filt = core.concrete_or_error(asarray, filt,
"Error arose in the `filt` argument of trim_zeros()")
nz = asarray(filt) == 0
if all(nz):
return empty(0, _dtype(filt))
@ -3808,6 +3818,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
@_wraps(np.unique)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None):
ar = core.concrete_or_error(array, ar, "The error arose in jnp.unique()")
if iscomplexobj(ar):
raise NotImplementedError(