Implement np.ix_, for non-bool inputs.

This commit is contained in:
Peter Hawkins 2019-06-17 17:08:27 -04:00
parent bd389b7fcf
commit ec685bf8ae
3 changed files with 38 additions and 2 deletions

View File

@ -549,13 +549,14 @@ def reshape(operand, new_sizes, dimensions=None):
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
"""
same_shape = onp.shape(operand) == tuple(new_sizes)
new_sizes = tuple(new_sizes)
same_shape = onp.shape(operand) == new_sizes
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
if onp.shape(operand) and same_shape and same_dims:
return operand
else:
return reshape_p.bind(
operand, new_sizes=tuple(new_sizes),
operand, new_sizes=new_sizes,
dimensions=None if dimensions is None else tuple(dimensions),
old_sizes=onp.shape(operand))

View File

@ -1398,6 +1398,25 @@ logspace = onp.logspace
geomspace = onp.geomspace
meshgrid = onp.meshgrid
@_wraps(onp.ix_)
def ix_(*args):
n = len(args)
output = []
for i, a in enumerate(args):
a = asarray(a)
if len(a.shape) != 1:
msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
raise ValueError(msg.format(a.shape))
if _dtype(a) == bool_:
raise NotImplementedError(
"Boolean arguments to jax.numpy.ix_ are not implemented")
shape = [1] * n
shape[i] = a.shape[0]
output.append(lax.reshape(a, shape))
return tuple(output)
@_wraps(onp.repeat)
def repeat(a, repeats, axis=None):
if not isscalar(repeats):

View File

@ -1453,6 +1453,22 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=True)
self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes),
"rng": jtu.rand_default(), "shapes": shapes, "dtypes": dtypes}
for shapes, dtypes in (
((), ()),
(((7,),), (onp.float32,)),
(((3,), (4,)), (onp.float32, onp.int32)),
(((3,), (0,), (4,)), (onp.int32, onp.float32, onp.int32)),
)))
def testIx_(self, rng, shapes, dtypes):
args_maker = lambda: [rng(shape, dtype)
for shape, dtype in zip(shapes, dtypes)]
self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker,
check_dtypes=True)
self._CompileAndCheck(lnp.ix_, args_maker, check_dtypes=True)
def testIssue330(self):
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
self.assertEqual(x[0, 0], 1)