mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
bf2abc886a
commit
1636d058df
@ -462,7 +462,8 @@ def full(shape, fill_value, dtype):
|
||||
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
|
||||
return FilledConstant(onp.asarray(fill_value, dtype), shape)
|
||||
elif isinstance(fill_value, xla.DeviceValue):
|
||||
return FilledConstant(convert_element_type(fill_value, dtype), shape)
|
||||
val = onp.asarray(fill_value, dtype)
|
||||
return FilledConstant(val, shape)
|
||||
else:
|
||||
return broadcast(convert_element_type(fill_value, dtype), shape)
|
||||
|
||||
|
@ -1187,7 +1187,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_n={}_increasing={}".format(
|
||||
jtu.format_shape_dtype_string([shape], dtype),
|
||||
@ -1207,6 +1206,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
def testIssue330(self):
|
||||
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
|
||||
self.assertEqual(x[0, 0], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user