Fix output dtype for np.full when dtype=None.

This commit is contained in:
Peter Hawkins 2019-02-27 11:09:33 -08:00
parent 2dae120d54
commit 8686e4dd3f
2 changed files with 5 additions and 5 deletions

View File

@ -777,7 +777,7 @@ def shaped_identity(x):
return shaped_identity_p.bind(x, shape=x.shape)
def full(shape, fill_value, dtype):
def full(shape, fill_value, dtype=None):
try:
shape = tuple(map(int, shape))
except TypeError:
@ -788,6 +788,7 @@ def full(shape, fill_value, dtype):
if onp.shape(fill_value):
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
raise TypeError(msg.format(onp.shape(fill_value)))
dtype = dtype or _dtype(fill_value)
dtype = xla_bridge.canonicalize_dtype(dtype)
# For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
@ -892,7 +893,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
`fill_value`, similar to the output of np.full.
"""
shape = onp.shape(x) if shape is None else shape
out = full(shape, fill_value, dtype or _dtype(x))
out = full(shape, fill_value, dtype)
return tie_in(x, out)

View File

@ -781,16 +781,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
lnp_fun = getattr(lnp, op)
self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outdtype={}".format(
jtu.format_shape_dtype_string(shape, fill_value_dtype),
onp.dtype(out_dtype).name),
onp.dtype(out_dtype).name if out_dtype else "None"),
"shape": shape, "fill_value_dtype": fill_value_dtype,
"out_dtype": out_dtype, "rng": jtu.rand_default()}
for shape in array_shapes
for fill_value_dtype in default_dtypes
for out_dtype in default_dtypes))
for out_dtype in [None] + default_dtypes))
def testFull(self, shape, fill_value_dtype, out_dtype, rng):
onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)