mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:26:04 +00:00
Merge pull request #4415 from jakevdp:indices-concrete
PiperOrigin-RevId: 334683008
This commit is contained in:
commit
768b0c1eb7
@ -2693,7 +2693,7 @@ def ix_(*args):
|
||||
|
||||
@_wraps(np.indices)
|
||||
def indices(dimensions, dtype=int32, sparse=False):
|
||||
dimensions = tuple(dimensions)
|
||||
dimensions = tuple(core.concrete_or_error(int, d, "dimensions argument of jnp.indices") for d in dimensions)
|
||||
N = len(dimensions)
|
||||
output = []
|
||||
s = dimensions
|
||||
@ -2703,7 +2703,7 @@ def indices(dimensions, dtype=int32, sparse=False):
|
||||
s = (1,)*i + (dim,) + (1,)*(N - i - 1)
|
||||
output.append(lax.broadcast_in_dim(idx, s, (i,)))
|
||||
if sparse:
|
||||
return tuple(output)
|
||||
return tuple(output)
|
||||
return stack(output, 0) if output else array([], dtype=dtype)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user