mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make jax.numpy.array(3) give 0D array, not scalar
the mechanism is to use lax.reshape (which was already there) and avoid the optimization that skipped actually calling reshape_p.bind fixes #121
This commit is contained in:
parent
adb15b7f4f
commit
f8aa563db1
@ -514,7 +514,7 @@ def reshape(operand, new_sizes, dimensions=None):
|
||||
"""
|
||||
same_shape = onp.shape(operand) == tuple(new_sizes)
|
||||
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
|
||||
if same_shape and same_dims:
|
||||
if onp.shape(operand) and same_shape and same_dims:
|
||||
return operand
|
||||
else:
|
||||
return reshape_p.bind(
|
||||
|
@ -1051,7 +1051,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_arg{}".format(i), "arg": arg}
|
||||
for i, arg in enumerate([
|
||||
[1, 2, 3], [1., 2., 3.],
|
||||
3., [1, 2, 3], [1., 2., 3.],
|
||||
[[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
|
||||
[[3, onp.array(2), 1], onp.arange(3.)],
|
||||
])))
|
||||
@ -1060,6 +1060,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)
|
||||
|
||||
def testIssue121(self):
|
||||
assert not onp.isscalar(lnp.array(3))
|
||||
|
||||
def testArrayMethod(self):
|
||||
class arraylike(object):
|
||||
dtype = onp.float32
|
||||
|
Loading…
x
Reference in New Issue
Block a user