Deprecate passing of None to jax.numpy.array

This commit is contained in:
Jake VanderPlas 2023-11-16 15:10:56 -08:00
parent 1fbcb24ec0
commit 84aa7e5c53
3 changed files with 19 additions and 1 deletions

View File

@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
## jaxlib 0.4.21

View File

@ -110,6 +110,8 @@ set_printoptions = np.set_printoptions
@util._wraps(np.iscomplexobj)
def iscomplexobj(x: Any) -> bool:
if x is None:
return False
try:
typ = x.dtype.type
except AttributeError:
@ -2085,6 +2087,13 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
object = object.__jax_array__()
object = tree_map(lambda leaf: leaf.__jax_array__()
if hasattr(leaf, "__jax_array__") else leaf, object)
leaves = tree_leaves(object, is_leaf=lambda x: x is None)
if any(leaf is None for leaf in leaves):
# Added Nov 16 2023
warnings.warn(
"None encountered in jnp.array(); this is currently treated as NaN. "
"In the future this will result in an error.",
FutureWarning, stacklevel=2)
leaves = tree_leaves(object)
if dtype is None:
# Use lattice_result_type rather than result_type to avoid canonicalization.

View File

@ -3218,6 +3218,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
jnp.array([0, val])
def testArrayNoneWarning(self):
# TODO(jakevdp): make this an error after the deprecation period.
with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"):
jnp.array([0.0, None])
def testIssue121(self):
assert not np.isscalar(jnp.array(3))