Merge pull request #6822 from hawkinsp:take

PiperOrigin-RevId: 375794394
This commit is contained in:
jax authors 2021-05-25 14:22:07 -07:00
commit 3b973ac04a
2 changed files with 13 additions and 0 deletions

View File

@ -4676,6 +4676,11 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
index_dims = len(shape(indices))
slice_sizes = list(shape(a))
if slice_sizes[axis_idx] == 0:
if indices.size != 0:
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
return a
slice_sizes[axis_idx] = _min(indices.size, 1)
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(

View File

@ -3933,6 +3933,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.array([], dtype=jnp.float32),
jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)))
np.testing.assert_array_equal(
jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32),
axis=1))
with self.assertRaisesRegex(IndexError, "non-empty jnp.take"):
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.array([0], jnp.int32), axis=1)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_index={}_axis={}".format(