mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
jnp.array
no longer accepts None
PiperOrigin-RevId: 743291099
This commit is contained in:
parent
e75b66463c
commit
9c58a112b3
@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## Unreleased
|
||||
|
||||
* Breaking changes
|
||||
|
||||
* {func}`jax.numpy.array` no longer accepts `None`. This behavior was
|
||||
deprecated since November 2023 and is now removed.
|
||||
|
||||
* Changes
|
||||
* The minimum CuDNN version is v9.8.
|
||||
* JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
|
||||
|
@ -5502,14 +5502,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
|
||||
leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None)
|
||||
if any(leaf is None for leaf in leaves):
|
||||
# Added Nov 16 2023
|
||||
if deprecations.is_accelerated("jax-numpy-array-none"):
|
||||
raise ValueError("None is not a valid value for jnp.array")
|
||||
warnings.warn(
|
||||
"None encountered in jnp.array(); this is currently treated as NaN. "
|
||||
"In the future this will result in an error.",
|
||||
FutureWarning, stacklevel=2)
|
||||
leaves, treedef = tree_flatten(object)
|
||||
raise ValueError("None is not a valid value for jnp.array")
|
||||
leaves = [
|
||||
leaf
|
||||
if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
|
||||
|
@ -47,7 +47,6 @@ from jax.test_util import check_grads
|
||||
from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -3796,16 +3795,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
|
||||
jnp.array([0, val])
|
||||
|
||||
def testArrayNoneWarning(self):
|
||||
if deprecations.is_accelerated('jax-numpy-array-none'):
|
||||
ctx = self.assertRaisesRegex(
|
||||
ValueError, 'None is not a valid value for jnp.array'
|
||||
)
|
||||
else:
|
||||
ctx = self.assertWarnsRegex(
|
||||
FutureWarning, r'None encountered in jnp.array\(\)'
|
||||
)
|
||||
with ctx:
|
||||
def testArrayNone(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'None is not a valid value for jnp.array'
|
||||
):
|
||||
jnp.array([0.0, None])
|
||||
|
||||
def testIssue121(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user