Finalize deprecation of lax.linalg positional args

PiperOrigin-RevId: 629581163
This commit is contained in:
Jake VanderPlas 2024-04-30 17:55:31 -07:00 committed by jax authors
parent f998122774
commit eced12d89b
2 changed files with 5 additions and 54 deletions

View File

@ -60,6 +60,8 @@ Remember to align the itemized text with the first line of an item within a list
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
positional-only, following deprecation of the keywords in JAX v0.4.21.
* Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
specified by keyword. Previously, this raised a DeprecationWarning.
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.

View File

@ -14,12 +14,10 @@
from __future__ import annotations
import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, Literal, TypeVar, overload
import warnings
from typing import Any, Callable, Literal, TypeVar, overload
import numpy as np
@ -62,51 +60,6 @@ TFun = TypeVar('TFun', bound=Callable[..., Any])
# traceables
# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
def _warn_on_positional_kwargs(f: TFun) -> TFun:
"""Decorator used for backward compatibility of keyword-only arguments.
Some functions were changed to mark their keyword arguments as keyword-only.
This decorator allows existing code to keep working temporarily, while issuing
a warning if a now keyword-only parameter is passed positionally."""
sig = inspect.signature(f)
pos_names = [name for name, p in sig.parameters.items()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
kwarg_names = [name for name, p in sig.parameters.items()
if p.kind == inspect.Parameter.KEYWORD_ONLY]
# This decorator assumes that all arguments to `f` are either
# positional-or-keyword or keyword-only.
assert len(pos_names) + len(kwarg_names) == len(sig.parameters)
@functools.wraps(f)
def wrapped(*args, **kwargs):
if len(args) < len(pos_names):
a = pos_names[len(args)]
raise TypeError(f"{f.__name__} missing required positional argument: {a}")
pos_args = args[:len(pos_names)]
extra_kwargs = args[len(pos_names):]
if len(extra_kwargs) > len(kwarg_names):
raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} "
f" arguments but {len(args)} were given.")
for name, value in zip(kwarg_names, extra_kwargs):
if name in kwargs:
raise TypeError(f"{f.__name__} got multiple values for argument: "
f"{name}")
warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
"argument. Support for passing it positionally will be "
"removed in an upcoming JAX release.",
DeprecationWarning)
kwargs[name] = value
return f(*pos_args, **kwargs)
return cast(TFun, wrapped)
@_warn_on_positional_kwargs
def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
"""Cholesky decomposition.
@ -136,7 +89,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
x = symmetrize(x)
return jnp.tril(cholesky_p.bind(x))
@_warn_on_positional_kwargs
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> list[Array]:
"""Eigendecomposition of a general matrix.
@ -162,7 +115,6 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors=compute_right_eigenvectors)
@_warn_on_positional_kwargs
def eigh(
x: Array,
*,
@ -267,7 +219,7 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
lu, pivots, permutation = lu_p.bind(x)
return lu, pivots, permutation
@_warn_on_positional_kwargs
def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
"""QR decomposition.
@ -333,7 +285,6 @@ def svd(
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(
x: ArrayLike,
*,
@ -361,7 +312,6 @@ def svd(
return s
@_warn_on_positional_kwargs
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
left_side: bool = False, lower: bool = False,
transpose_a: bool = False, conjugate_a: bool = False,
@ -2208,7 +2158,6 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
# Schur Decomposition
@_warn_on_positional_kwargs
def schur(x: ArrayLike, *,
compute_schur_vectors: bool = True,
sort_eig_vals: bool = False,