1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

jnp.array no longer accepts None

PiperOrigin-RevId: 743291099
This commit is contained in:
Sergei Lebedev 2025-04-02 14:58:06 -07:00 committed by jax authors
parent e75b66463c
commit 9c58a112b3
3 changed files with 10 additions and 19 deletions

@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* Breaking changes
* {func}`jax.numpy.array` no longer accepts `None`. This behavior was
deprecated since November 2023 and is now removed.
* Changes
* The minimum CuDNN version is v9.8.
* JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain

@ -5502,14 +5502,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None)
if any(leaf is None for leaf in leaves):
# Added Nov 16 2023
if deprecations.is_accelerated("jax-numpy-array-none"):
raise ValueError("None is not a valid value for jnp.array")
warnings.warn(
"None encountered in jnp.array(); this is currently treated as NaN. "
"In the future this will result in an error.",
FutureWarning, stacklevel=2)
leaves, treedef = tree_flatten(object)
raise ValueError("None is not a valid value for jnp.array")
leaves = [
leaf
if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None

@ -47,7 +47,6 @@ from jax.test_util import check_grads
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
@ -3796,16 +3795,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
jnp.array([0, val])
def testArrayNoneWarning(self):
if deprecations.is_accelerated('jax-numpy-array-none'):
ctx = self.assertRaisesRegex(
ValueError, 'None is not a valid value for jnp.array'
)
else:
ctx = self.assertWarnsRegex(
FutureWarning, r'None encountered in jnp.array\(\)'
)
with ctx:
def testArrayNone(self):
with self.assertRaisesRegex(
ValueError, 'None is not a valid value for jnp.array'
):
jnp.array([0.0, None])
def testIssue121(self):