mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate passing of None to jax.numpy.array
This commit is contained in:
parent
1fbcb24ec0
commit
84aa7e5c53
@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* The previously-deprecated `sym_pos` argument has been removed from
|
||||
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
||||
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
|
||||
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
|
||||
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
||||
|
||||
|
||||
## jaxlib 0.4.21
|
||||
|
||||
|
@ -110,6 +110,8 @@ set_printoptions = np.set_printoptions
|
||||
|
||||
@util._wraps(np.iscomplexobj)
|
||||
def iscomplexobj(x: Any) -> bool:
|
||||
if x is None:
|
||||
return False
|
||||
try:
|
||||
typ = x.dtype.type
|
||||
except AttributeError:
|
||||
@ -2085,7 +2087,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
object = object.__jax_array__()
|
||||
object = tree_map(lambda leaf: leaf.__jax_array__()
|
||||
if hasattr(leaf, "__jax_array__") else leaf, object)
|
||||
leaves = tree_leaves(object)
|
||||
leaves = tree_leaves(object, is_leaf=lambda x: x is None)
|
||||
if any(leaf is None for leaf in leaves):
|
||||
# Added Nov 16 2023
|
||||
warnings.warn(
|
||||
"None encountered in jnp.array(); this is currently treated as NaN. "
|
||||
"In the future this will result in an error.",
|
||||
FutureWarning, stacklevel=2)
|
||||
leaves = tree_leaves(object)
|
||||
if dtype is None:
|
||||
# Use lattice_result_type rather than result_type to avoid canonicalization.
|
||||
# Otherwise, weakly-typed inputs would have their dtypes canonicalized.
|
||||
|
@ -3218,6 +3218,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
|
||||
jnp.array([0, val])
|
||||
|
||||
def testArrayNoneWarning(self):
|
||||
# TODO(jakevdp): make this an error after the deprecation period.
|
||||
with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"):
|
||||
jnp.array([0.0, None])
|
||||
|
||||
def testIssue121(self):
|
||||
assert not np.isscalar(jnp.array(3))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user