mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix output dtype for np.full when dtype=None.
This commit is contained in:
parent
2dae120d54
commit
8686e4dd3f
@ -777,7 +777,7 @@ def shaped_identity(x):
|
||||
return shaped_identity_p.bind(x, shape=x.shape)
|
||||
|
||||
|
||||
def full(shape, fill_value, dtype):
|
||||
def full(shape, fill_value, dtype=None):
|
||||
try:
|
||||
shape = tuple(map(int, shape))
|
||||
except TypeError:
|
||||
@ -788,6 +788,7 @@ def full(shape, fill_value, dtype):
|
||||
if onp.shape(fill_value):
|
||||
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
|
||||
raise TypeError(msg.format(onp.shape(fill_value)))
|
||||
dtype = dtype or _dtype(fill_value)
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
|
||||
# For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
|
||||
@ -892,7 +893,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
|
||||
`fill_value`, similar to the output of np.full.
|
||||
"""
|
||||
shape = onp.shape(x) if shape is None else shape
|
||||
out = full(shape, fill_value, dtype or _dtype(x))
|
||||
out = full(shape, fill_value, dtype)
|
||||
return tie_in(x, out)
|
||||
|
||||
|
||||
|
@ -781,16 +781,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
lnp_fun = getattr(lnp, op)
|
||||
self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outdtype={}".format(
|
||||
jtu.format_shape_dtype_string(shape, fill_value_dtype),
|
||||
onp.dtype(out_dtype).name),
|
||||
onp.dtype(out_dtype).name if out_dtype else "None"),
|
||||
"shape": shape, "fill_value_dtype": fill_value_dtype,
|
||||
"out_dtype": out_dtype, "rng": jtu.rand_default()}
|
||||
for shape in array_shapes
|
||||
for fill_value_dtype in default_dtypes
|
||||
for out_dtype in default_dtypes))
|
||||
for out_dtype in [None] + default_dtypes))
|
||||
def testFull(self, shape, fill_value_dtype, out_dtype, rng):
|
||||
onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
|
||||
lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user