mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Implement np.ix_, for non-bool inputs.
This commit is contained in:
parent
bd389b7fcf
commit
ec685bf8ae
@ -549,13 +549,14 @@ def reshape(operand, new_sizes, dimensions=None):
|
||||
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
|
||||
operator.
|
||||
"""
|
||||
same_shape = onp.shape(operand) == tuple(new_sizes)
|
||||
new_sizes = tuple(new_sizes)
|
||||
same_shape = onp.shape(operand) == new_sizes
|
||||
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
|
||||
if onp.shape(operand) and same_shape and same_dims:
|
||||
return operand
|
||||
else:
|
||||
return reshape_p.bind(
|
||||
operand, new_sizes=tuple(new_sizes),
|
||||
operand, new_sizes=new_sizes,
|
||||
dimensions=None if dimensions is None else tuple(dimensions),
|
||||
old_sizes=onp.shape(operand))
|
||||
|
||||
|
@ -1398,6 +1398,25 @@ logspace = onp.logspace
|
||||
geomspace = onp.geomspace
|
||||
meshgrid = onp.meshgrid
|
||||
|
||||
|
||||
@_wraps(onp.ix_)
|
||||
def ix_(*args):
|
||||
n = len(args)
|
||||
output = []
|
||||
for i, a in enumerate(args):
|
||||
a = asarray(a)
|
||||
if len(a.shape) != 1:
|
||||
msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
|
||||
raise ValueError(msg.format(a.shape))
|
||||
if _dtype(a) == bool_:
|
||||
raise NotImplementedError(
|
||||
"Boolean arguments to jax.numpy.ix_ are not implemented")
|
||||
shape = [1] * n
|
||||
shape[i] = a.shape[0]
|
||||
output.append(lax.reshape(a, shape))
|
||||
return tuple(output)
|
||||
|
||||
|
||||
@_wraps(onp.repeat)
|
||||
def repeat(a, repeats, axis=None):
|
||||
if not isscalar(repeats):
|
||||
|
@ -1453,6 +1453,22 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes),
|
||||
"rng": jtu.rand_default(), "shapes": shapes, "dtypes": dtypes}
|
||||
for shapes, dtypes in (
|
||||
((), ()),
|
||||
(((7,),), (onp.float32,)),
|
||||
(((3,), (4,)), (onp.float32, onp.int32)),
|
||||
(((3,), (0,), (4,)), (onp.int32, onp.float32, onp.int32)),
|
||||
)))
|
||||
def testIx_(self, rng, shapes, dtypes):
|
||||
args_maker = lambda: [rng(shape, dtype)
|
||||
for shape, dtype in zip(shapes, dtypes)]
|
||||
self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(lnp.ix_, args_maker, check_dtypes=True)
|
||||
|
||||
def testIssue330(self):
|
||||
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
|
||||
self.assertEqual(x[0, 0], 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user