Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881)

This commit is contained in:
Peter Hawkins 2019-12-17 17:20:51 -05:00 committed by GitHub
parent 96677d9c6f
commit d8d3a7bc87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 3 deletions

View File

@ -1704,6 +1704,7 @@ def ones_like(x, dtype=None):
@_wraps(onp.full)
def full(shape, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full")
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, fill_value, dtype)
@ -1719,7 +1720,7 @@ def zeros(shape, dtype=None):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "zeros")
dtype = float_ if dtype is None else dtype
shape = (shape,) if isscalar(shape) else shape
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, 0, dtype)
@_wraps(onp.ones)
@ -1728,7 +1729,7 @@ def ones(shape, dtype=None):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "ones")
dtype = float_ if dtype is None else dtype
shape = (shape,) if isscalar(shape) else shape
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, 1, dtype)

View File

@ -344,6 +344,8 @@ def format_shape_dtype_string(shape, dtype):
return '{}[{}]'.format(dtype_str(dtype), shapestr)
elif type(shape) is int:
return '{}[{},]'.format(dtype_str(dtype), shape)
elif isinstance(shape, onp.ndarray):
return '{}[{}]'.format(dtype_str(dtype), shape)
else:
raise TypeError(type(shape))

View File

@ -1185,7 +1185,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
onp.dtype(out_dtype).name if out_dtype else "None"),
"shape": shape, "fill_value_dtype": fill_value_dtype,
"out_dtype": out_dtype, "rng_factory": jtu.rand_default}
for shape in array_shapes
for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)]
for fill_value_dtype in default_dtypes
for out_dtype in [None] + default_dtypes))
def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory):
@ -1196,6 +1196,23 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype),
"onp_op": getattr(onp, op), "lnp_op": getattr(lnp, op),
"shape": shape, "dtype": dtype}
for op in ["zeros", "ones"]
for shape in [2, (), (2,), (3, 0), onp.array((4, 5, 6), dtype=onp.int32),
onp.array(4, dtype=onp.int32)]
for dtype in all_dtypes))
def testZerosOnes(self, onp_op, lnp_op, shape, dtype):
rng = jtu.rand_default()
def args_maker(): return []
onp_op = partial(onp_op, shape, dtype)
lnp_op = partial(lnp_op, shape, dtype)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
jtu.format_shape_dtype_string(shape, in_dtype),