mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:16:05 +00:00
Improve errors for failed compilations w/ core.concrete_or_error
This commit is contained in:
parent
95e3fcdf90
commit
492889f4a4
@ -1459,6 +1459,8 @@ def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
|
||||
|
||||
@_wraps(np.intersect1d)
|
||||
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
ar1 = core.concrete_or_error(asarray, ar1, "The error arose in intersect1d()")
|
||||
ar2 = core.concrete_or_error(asarray, ar2, "The error arose in intersect1d()")
|
||||
|
||||
if not assume_unique:
|
||||
if return_indices:
|
||||
@ -1558,6 +1560,9 @@ def bincount(x, weights=None, minlength=0, *, length=None):
|
||||
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
|
||||
raise TypeError(msg)
|
||||
if length is None:
|
||||
x = core.concrete_or_error(array, x,
|
||||
"The error occured because of argument 'x' of jnp.bincount. "
|
||||
"To avoid this error, pass a static `length` argument.")
|
||||
length = max(x) + 1
|
||||
length = _max(length, minlength)
|
||||
if ndim(x) != 1:
|
||||
@ -1997,7 +2002,7 @@ because its output shape is data-dependent.
|
||||
def nonzero(a):
|
||||
# Note: this function cannot be jitted because its output has a dynamic
|
||||
# shape.
|
||||
a = atleast_1d(a)
|
||||
a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
|
||||
dims = shape(a)
|
||||
ndims = len(dims)
|
||||
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
|
||||
@ -2778,7 +2783,10 @@ def repeat(a, repeats, axis=None, *, total_repeat_length=None):
|
||||
|
||||
# If total_repeat_length is not given, can't compile, use a default.
|
||||
if total_repeat_length is None:
|
||||
repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.")
|
||||
repeats = core.concrete_or_error(np.array, repeats,
|
||||
"When jit-compiling jnp.repeat, the total number of repeats must be static. "
|
||||
"To fix this, either specify a static value for `repeats`, or pass a static "
|
||||
"value to `total_repeat_length`.")
|
||||
repeats = np.ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = np.broadcast_to(repeats, [a.shape[axis]])
|
||||
@ -3021,6 +3029,8 @@ def polyder(p, m=1):
|
||||
|
||||
@_wraps(np.trim_zeros)
|
||||
def trim_zeros(filt, trim='fb'):
|
||||
filt = core.concrete_or_error(asarray, filt,
|
||||
"Error arose in the `filt` argument of trim_zeros()")
|
||||
nz = asarray(filt) == 0
|
||||
if all(nz):
|
||||
return empty(0, _dtype(filt))
|
||||
@ -3808,6 +3818,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
|
||||
@_wraps(np.unique)
|
||||
def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis=None):
|
||||
ar = core.concrete_or_error(array, ar, "The error arose in jnp.unique()")
|
||||
|
||||
if iscomplexobj(ar):
|
||||
raise NotImplementedError(
|
||||
|
Loading…
x
Reference in New Issue
Block a user