Merge pull request #8534 from jakevdp:array-dtype

PiperOrigin-RevId: 410106997
This commit is contained in:
jax authors 2021-11-15 16:21:00 -08:00
commit be751d1dd6
2 changed files with 42 additions and 1 deletions

View File

@ -3594,7 +3594,15 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
raise TypeError("Unexpected input type for array: {}".format(type(object)))
out = lax._convert_element_type(out, dtype, weak_type=weak_type)
if weak_type:
# Here we make a judgment call: we only return a weakly-typed array when obj
# itself is weakly typed. That ensures array(x) is a no-op whenever x is weak,
# but avoids introducing weak types with something like array([1, 2, 3])
out = lax._convert_element_type(out, dtype, weak_type=True)
else:
# If dtype is not specified, we use result_type(out). This ensures JIT invariance
# with, e.g. lists of scalars.
out = lax._convert_element_type(out, dtype or result_type(out))
if ndmin > ndim(out):
out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))

View File

@ -3560,6 +3560,39 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
canonicalize_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*")
def testArrayDtypeInference(self):
def _check(obj, out_dtype, weak_type):
dtype_reference = np.array(obj, dtype=out_dtype)
out = jnp.array(obj)
self.assertDtypesMatch(out, dtype_reference)
self.assertEqual(dtypes.is_weakly_typed(out), weak_type)
out_jit = jnp.array(obj)
self.assertDtypesMatch(out_jit, dtype_reference)
self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type)
# Python scalars become 64-bit weak types.
_check(1, np.int64, True)
_check(1.0, np.float64, True)
_check(1.0j, np.complex128, True)
# Lists become strongly-typed defaults.
_check([1], jnp.int_, False)
_check([1.0], jnp.float_, False)
_check([1.0j], jnp.complex_, False)
# Lists of weakly-typed objects become strongly-typed defaults.
_check([jnp.array(1)], jnp.int_, False)
_check([jnp.array(1.0)], jnp.float_, False)
_check([jnp.array(1.0j)], jnp.complex_, False)
# Lists of strongly-typed objects maintain their strong type.
_check([jnp.int64(1)], np.int64, False)
_check([jnp.float64(1)], np.float64, False)
_check([jnp.complex128(1)], np.complex128, False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype)}", "dtype": dtype}
for dtype in all_dtypes))