mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8534 from jakevdp:array-dtype
PiperOrigin-RevId: 410106997
This commit is contained in:
commit
be751d1dd6
@ -3594,7 +3594,15 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
|
||||
|
||||
raise TypeError("Unexpected input type for array: {}".format(type(object)))
|
||||
|
||||
out = lax._convert_element_type(out, dtype, weak_type=weak_type)
|
||||
if weak_type:
|
||||
# Here we make a judgment call: we only return a weakly-typed array when obj
|
||||
# itself is weakly typed. That ensures array(x) is a no-op whenever x is weak,
|
||||
# but avoids introducing weak types with something like array([1, 2, 3])
|
||||
out = lax._convert_element_type(out, dtype, weak_type=True)
|
||||
else:
|
||||
# If dtype is not specified, we use result_type(out). This ensures JIT invariance
|
||||
# with, e.g. lists of scalars.
|
||||
out = lax._convert_element_type(out, dtype or result_type(out))
|
||||
|
||||
if ndmin > ndim(out):
|
||||
out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
|
||||
|
@ -3560,6 +3560,39 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
canonicalize_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*")
|
||||
def testArrayDtypeInference(self):
|
||||
def _check(obj, out_dtype, weak_type):
|
||||
dtype_reference = np.array(obj, dtype=out_dtype)
|
||||
|
||||
out = jnp.array(obj)
|
||||
self.assertDtypesMatch(out, dtype_reference)
|
||||
self.assertEqual(dtypes.is_weakly_typed(out), weak_type)
|
||||
|
||||
out_jit = jnp.array(obj)
|
||||
self.assertDtypesMatch(out_jit, dtype_reference)
|
||||
self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type)
|
||||
|
||||
# Python scalars become 64-bit weak types.
|
||||
_check(1, np.int64, True)
|
||||
_check(1.0, np.float64, True)
|
||||
_check(1.0j, np.complex128, True)
|
||||
|
||||
# Lists become strongly-typed defaults.
|
||||
_check([1], jnp.int_, False)
|
||||
_check([1.0], jnp.float_, False)
|
||||
_check([1.0j], jnp.complex_, False)
|
||||
|
||||
# Lists of weakly-typed objects become strongly-typed defaults.
|
||||
_check([jnp.array(1)], jnp.int_, False)
|
||||
_check([jnp.array(1.0)], jnp.float_, False)
|
||||
_check([jnp.array(1.0j)], jnp.complex_, False)
|
||||
|
||||
# Lists of strongly-typed objects maintain their strong type.
|
||||
_check([jnp.int64(1)], np.int64, False)
|
||||
_check([jnp.float64(1)], np.float64, False)
|
||||
_check([jnp.complex128(1)], np.complex128, False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype)}", "dtype": dtype}
|
||||
for dtype in all_dtypes))
|
||||
|
Loading…
x
Reference in New Issue
Block a user