diff --git a/CHANGELOG.md b/CHANGELOG.md index 271a8f597..c353f1cde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * The previously-deprecated `sym_pos` argument has been removed from {func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead. + * Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or + within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`. + It currently is converted to NaN, and in the future will raise a {obj}`TypeError`. + ## jaxlib 0.4.21 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b0b266a44..bf3c9adc8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -110,6 +110,8 @@ set_printoptions = np.set_printoptions @util._wraps(np.iscomplexobj) def iscomplexobj(x: Any) -> bool: + if x is None: + return False try: typ = x.dtype.type except AttributeError: @@ -2085,7 +2087,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = object.__jax_array__() object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object) + leaves = tree_leaves(object, is_leaf=lambda x: x is None) + if any(leaf is None for leaf in leaves): + # Added Nov 16 2023 + warnings.warn( + "None encountered in jnp.array(); this is currently treated as NaN. " + "In the future this will result in an error.", + FutureWarning, stacklevel=2) + leaves = tree_leaves(object) if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 171ac75bd..51cf262eb 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3218,6 +3218,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) + def testArrayNoneWarning(self): + # TODO(jakevdp): make this an error after the deprecation period. + with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"): + jnp.array([0.0, None]) + def testIssue121(self): assert not np.isscalar(jnp.array(3))