diff --git a/CHANGELOG.md b/CHANGELOG.md index 50cdc23f8..785d3da3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,8 @@ Remember to align the itemized text with the first line of an item within a list {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`. * The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now positional-only, following deprecation of the keywords in JAX v0.4.21. + * Non-array arguments to functions in {mod}`jax.lax.linalg` now must be + specified by keyword. Previously, this raised a DeprecationWarning. * Bug fixes * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 80162a204..175b4648b 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -14,12 +14,10 @@ from __future__ import annotations -import inspect import functools from functools import partial import math -from typing import cast, Any, Callable, Literal, TypeVar, overload -import warnings +from typing import Any, Callable, Literal, TypeVar, overload import numpy as np @@ -62,51 +60,6 @@ TFun = TypeVar('TFun', bound=Callable[..., Any]) # traceables -# TODO(phawkins): remove backward compatibility shim after 2022/08/11. -def _warn_on_positional_kwargs(f: TFun) -> TFun: - """Decorator used for backward compatibility of keyword-only arguments. - - Some functions were changed to mark their keyword arguments as keyword-only. - This decorator allows existing code to keep working temporarily, while issuing - a warning if a now keyword-only parameter is passed positionally.""" - sig = inspect.signature(f) - pos_names = [name for name, p in sig.parameters.items() - if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD] - kwarg_names = [name for name, p in sig.parameters.items() - if p.kind == inspect.Parameter.KEYWORD_ONLY] - - # This decorator assumes that all arguments to `f` are either - # positional-or-keyword or keyword-only. - assert len(pos_names) + len(kwarg_names) == len(sig.parameters) - - @functools.wraps(f) - def wrapped(*args, **kwargs): - if len(args) < len(pos_names): - a = pos_names[len(args)] - raise TypeError(f"{f.__name__} missing required positional argument: {a}") - - pos_args = args[:len(pos_names)] - extra_kwargs = args[len(pos_names):] - - if len(extra_kwargs) > len(kwarg_names): - raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} " - f" arguments but {len(args)} were given.") - - for name, value in zip(kwarg_names, extra_kwargs): - if name in kwargs: - raise TypeError(f"{f.__name__} got multiple values for argument: " - f"{name}") - - warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only " - "argument. Support for passing it positionally will be " - "removed in an upcoming JAX release.", - DeprecationWarning) - kwargs[name] = value - return f(*pos_args, **kwargs) - - return cast(TFun, wrapped) - -@_warn_on_positional_kwargs def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: """Cholesky decomposition. @@ -136,7 +89,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: x = symmetrize(x) return jnp.tril(cholesky_p.bind(x)) -@_warn_on_positional_kwargs + def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> list[Array]: """Eigendecomposition of a general matrix. @@ -162,7 +115,6 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors=compute_right_eigenvectors) -@_warn_on_positional_kwargs def eigh( x: Array, *, @@ -267,7 +219,7 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]: lu, pivots, permutation = lu_p.bind(x) return lu, pivots, permutation -@_warn_on_positional_kwargs + def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: """QR decomposition. @@ -333,7 +285,6 @@ def svd( # TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD. -@_warn_on_positional_kwargs def svd( x: ArrayLike, *, @@ -361,7 +312,6 @@ def svd( return s -@_warn_on_positional_kwargs def triangular_solve(a: ArrayLike, b: ArrayLike, *, left_side: bool = False, lower: bool = False, transpose_a: bool = False, conjugate_a: bool = False, @@ -2208,7 +2158,6 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: # Schur Decomposition -@_warn_on_positional_kwargs def schur(x: ArrayLike, *, compute_schur_vectors: bool = True, sort_eig_vals: bool = False,