mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate passing of None to jax.numpy.array
This commit is contained in:
parent
1fbcb24ec0
commit
84aa7e5c53
@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Deprecations
|
* Deprecations
|
||||||
* The previously-deprecated `sym_pos` argument has been removed from
|
* The previously-deprecated `sym_pos` argument has been removed from
|
||||||
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
||||||
|
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
|
||||||
|
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
|
||||||
|
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
||||||
|
|
||||||
|
|
||||||
## jaxlib 0.4.21
|
## jaxlib 0.4.21
|
||||||
|
|
||||||
|
@ -110,6 +110,8 @@ set_printoptions = np.set_printoptions
|
|||||||
|
|
||||||
@util._wraps(np.iscomplexobj)
|
@util._wraps(np.iscomplexobj)
|
||||||
def iscomplexobj(x: Any) -> bool:
|
def iscomplexobj(x: Any) -> bool:
|
||||||
|
if x is None:
|
||||||
|
return False
|
||||||
try:
|
try:
|
||||||
typ = x.dtype.type
|
typ = x.dtype.type
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -2085,7 +2087,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
|||||||
object = object.__jax_array__()
|
object = object.__jax_array__()
|
||||||
object = tree_map(lambda leaf: leaf.__jax_array__()
|
object = tree_map(lambda leaf: leaf.__jax_array__()
|
||||||
if hasattr(leaf, "__jax_array__") else leaf, object)
|
if hasattr(leaf, "__jax_array__") else leaf, object)
|
||||||
leaves = tree_leaves(object)
|
leaves = tree_leaves(object, is_leaf=lambda x: x is None)
|
||||||
|
if any(leaf is None for leaf in leaves):
|
||||||
|
# Added Nov 16 2023
|
||||||
|
warnings.warn(
|
||||||
|
"None encountered in jnp.array(); this is currently treated as NaN. "
|
||||||
|
"In the future this will result in an error.",
|
||||||
|
FutureWarning, stacklevel=2)
|
||||||
|
leaves = tree_leaves(object)
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
# Use lattice_result_type rather than result_type to avoid canonicalization.
|
# Use lattice_result_type rather than result_type to avoid canonicalization.
|
||||||
# Otherwise, weakly-typed inputs would have their dtypes canonicalized.
|
# Otherwise, weakly-typed inputs would have their dtypes canonicalized.
|
||||||
|
@ -3218,6 +3218,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
|
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
|
||||||
jnp.array([0, val])
|
jnp.array([0, val])
|
||||||
|
|
||||||
|
def testArrayNoneWarning(self):
|
||||||
|
# TODO(jakevdp): make this an error after the deprecation period.
|
||||||
|
with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"):
|
||||||
|
jnp.array([0.0, None])
|
||||||
|
|
||||||
def testIssue121(self):
|
def testIssue121(self):
|
||||||
assert not np.isscalar(jnp.array(3))
|
assert not np.isscalar(jnp.array(3))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user