make jax.numpy.array(3) give 0D array, not scalar

the mechanism is to use lax.reshape (which was already there) and avoid
the optimization that skipped actually calling reshape_p.bind

fixes #121
This commit is contained in:
Matthew Johnson 2019-05-20 11:49:09 -07:00
parent adb15b7f4f
commit f8aa563db1
2 changed files with 5 additions and 2 deletions

View File

@ -514,7 +514,7 @@ def reshape(operand, new_sizes, dimensions=None):
"""
same_shape = onp.shape(operand) == tuple(new_sizes)
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
if same_shape and same_dims:
if onp.shape(operand) and same_shape and same_dims:
return operand
else:
return reshape_p.bind(

View File

@ -1051,7 +1051,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_arg{}".format(i), "arg": arg}
for i, arg in enumerate([
[1, 2, 3], [1., 2., 3.],
3., [1, 2, 3], [1., 2., 3.],
[[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
[[3, onp.array(2), 1], onp.arange(3.)],
])))
@ -1060,6 +1060,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)
def testIssue121(self):
assert not onp.isscalar(lnp.array(3))
def testArrayMethod(self):
class arraylike(object):
dtype = onp.float32