mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881)
This commit is contained in:
parent
96677d9c6f
commit
d8d3a7bc87
@ -1704,6 +1704,7 @@ def ones_like(x, dtype=None):
|
||||
@_wraps(onp.full)
|
||||
def full(shape, fill_value, dtype=None):
|
||||
lax._check_user_dtype_supported(dtype, "full")
|
||||
shape = (shape,) if ndim(shape) == 0 else shape
|
||||
return lax.full(shape, fill_value, dtype)
|
||||
|
||||
|
||||
@ -1719,7 +1720,7 @@ def zeros(shape, dtype=None):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
lax._check_user_dtype_supported(dtype, "zeros")
|
||||
dtype = float_ if dtype is None else dtype
|
||||
shape = (shape,) if isscalar(shape) else shape
|
||||
shape = (shape,) if ndim(shape) == 0 else shape
|
||||
return lax.full(shape, 0, dtype)
|
||||
|
||||
@_wraps(onp.ones)
|
||||
@ -1728,7 +1729,7 @@ def ones(shape, dtype=None):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
lax._check_user_dtype_supported(dtype, "ones")
|
||||
dtype = float_ if dtype is None else dtype
|
||||
shape = (shape,) if isscalar(shape) else shape
|
||||
shape = (shape,) if ndim(shape) == 0 else shape
|
||||
return lax.full(shape, 1, dtype)
|
||||
|
||||
|
||||
|
@ -344,6 +344,8 @@ def format_shape_dtype_string(shape, dtype):
|
||||
return '{}[{}]'.format(dtype_str(dtype), shapestr)
|
||||
elif type(shape) is int:
|
||||
return '{}[{},]'.format(dtype_str(dtype), shape)
|
||||
elif isinstance(shape, onp.ndarray):
|
||||
return '{}[{}]'.format(dtype_str(dtype), shape)
|
||||
else:
|
||||
raise TypeError(type(shape))
|
||||
|
||||
|
@ -1185,7 +1185,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
onp.dtype(out_dtype).name if out_dtype else "None"),
|
||||
"shape": shape, "fill_value_dtype": fill_value_dtype,
|
||||
"out_dtype": out_dtype, "rng_factory": jtu.rand_default}
|
||||
for shape in array_shapes
|
||||
for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)]
|
||||
for fill_value_dtype in default_dtypes
|
||||
for out_dtype in [None] + default_dtypes))
|
||||
def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory):
|
||||
@ -1196,6 +1196,23 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype),
|
||||
"onp_op": getattr(onp, op), "lnp_op": getattr(lnp, op),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for op in ["zeros", "ones"]
|
||||
for shape in [2, (), (2,), (3, 0), onp.array((4, 5, 6), dtype=onp.int32),
|
||||
onp.array(4, dtype=onp.int32)]
|
||||
for dtype in all_dtypes))
|
||||
def testZerosOnes(self, onp_op, lnp_op, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
def args_maker(): return []
|
||||
onp_op = partial(onp_op, shape, dtype)
|
||||
lnp_op = partial(lnp_op, shape, dtype)
|
||||
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
|
||||
jtu.format_shape_dtype_string(shape, in_dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user