mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6822 from hawkinsp:take
PiperOrigin-RevId: 375794394
This commit is contained in:
commit
3b973ac04a
@ -4676,6 +4676,11 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
|
||||
index_dims = len(shape(indices))
|
||||
slice_sizes = list(shape(a))
|
||||
if slice_sizes[axis_idx] == 0:
|
||||
if indices.size != 0:
|
||||
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
|
||||
return a
|
||||
|
||||
slice_sizes[axis_idx] = _min(indices.size, 1)
|
||||
dnums = lax.GatherDimensionNumbers(
|
||||
offset_dims=tuple(
|
||||
|
@ -3933,6 +3933,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.array([], dtype=jnp.float32),
|
||||
jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)))
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
jnp.ones((2, 0, 4), dtype=jnp.float32),
|
||||
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32),
|
||||
axis=1))
|
||||
|
||||
with self.assertRaisesRegex(IndexError, "non-empty jnp.take"):
|
||||
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32),
|
||||
jnp.array([0], jnp.int32), axis=1)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_index={}_axis={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user