jnp.unique: small implementation cleanup

This commit is contained in:
Jake VanderPlas 2023-11-09 14:24:43 -08:00
parent 1b79395d32
commit caf930e467

View File

@ -270,9 +270,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
else:
ind = nonzero(mask, size=size)[0]
result = aux[ind] if aux.size else aux
if fill_value is not None:
fill_value = asarray(fill_value, dtype=result.dtype)
if size is not None and fill_value is not None:
fill_value = asarray(fill_value, dtype=result.dtype)
if result.shape[0]:
valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
result = where(valid, result, fill_value)