mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jnp.unique: small implementation cleanup
This commit is contained in:
parent
1b79395d32
commit
caf930e467
@ -270,9 +270,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
else:
|
||||
ind = nonzero(mask, size=size)[0]
|
||||
result = aux[ind] if aux.size else aux
|
||||
if fill_value is not None:
|
||||
fill_value = asarray(fill_value, dtype=result.dtype)
|
||||
if size is not None and fill_value is not None:
|
||||
fill_value = asarray(fill_value, dtype=result.dtype)
|
||||
if result.shape[0]:
|
||||
valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
|
||||
result = where(valid, result, fill_value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user