Merge pull request #4415 from jakevdp:indices-concrete

PiperOrigin-RevId: 334683008
This commit is contained in:
jax authors 2020-09-30 14:44:47 -07:00
commit 768b0c1eb7

View File

@ -2693,7 +2693,7 @@ def ix_(*args):
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
dimensions = tuple(dimensions)
dimensions = tuple(core.concrete_or_error(int, d, "dimensions argument of jnp.indices") for d in dimensions)
N = len(dimensions)
output = []
s = dimensions
@ -2703,7 +2703,7 @@ def indices(dimensions, dtype=int32, sparse=False):
s = (1,)*i + (dim,) + (1,)*(N - i - 1)
output.append(lax.broadcast_in_dim(idx, s, (i,)))
if sparse:
return tuple(output)
return tuple(output)
return stack(output, 0) if output else array([], dtype=dtype)