fix lax.full handling of DeviceConstant scalars

fixes #330
This commit is contained in:
Matthew Johnson 2019-02-06 09:23:34 -08:00
parent bf2abc886a
commit 1636d058df
2 changed files with 6 additions and 2 deletions

View File

@ -462,7 +462,8 @@ def full(shape, fill_value, dtype):
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
return FilledConstant(onp.asarray(fill_value, dtype), shape)
elif isinstance(fill_value, xla.DeviceValue):
return FilledConstant(convert_element_type(fill_value, dtype), shape)
val = onp.asarray(fill_value, dtype)
return FilledConstant(val, shape)
else:
return broadcast(convert_element_type(fill_value, dtype), shape)

View File

@ -1187,7 +1187,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_n={}_increasing={}".format(
jtu.format_shape_dtype_string([shape], dtype),
@ -1207,6 +1206,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False)
def testIssue330(self):
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
self.assertEqual(x[0, 0], 1)
if __name__ == "__main__":
absltest.main()