mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of lax.linalg positional args
PiperOrigin-RevId: 629581163
This commit is contained in:
parent
f998122774
commit
eced12d89b
@ -60,6 +60,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
|
||||
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
|
||||
positional-only, following deprecation of the keywords in JAX v0.4.21.
|
||||
* Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
|
||||
specified by keyword. Previously, this raised a DeprecationWarning.
|
||||
|
||||
* Bug fixes
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
|
@ -14,12 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import functools
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import cast, Any, Callable, Literal, TypeVar, overload
|
||||
import warnings
|
||||
from typing import Any, Callable, Literal, TypeVar, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -62,51 +60,6 @@ TFun = TypeVar('TFun', bound=Callable[..., Any])
|
||||
|
||||
# traceables
|
||||
|
||||
# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
|
||||
def _warn_on_positional_kwargs(f: TFun) -> TFun:
|
||||
"""Decorator used for backward compatibility of keyword-only arguments.
|
||||
|
||||
Some functions were changed to mark their keyword arguments as keyword-only.
|
||||
This decorator allows existing code to keep working temporarily, while issuing
|
||||
a warning if a now keyword-only parameter is passed positionally."""
|
||||
sig = inspect.signature(f)
|
||||
pos_names = [name for name, p in sig.parameters.items()
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
||||
kwarg_names = [name for name, p in sig.parameters.items()
|
||||
if p.kind == inspect.Parameter.KEYWORD_ONLY]
|
||||
|
||||
# This decorator assumes that all arguments to `f` are either
|
||||
# positional-or-keyword or keyword-only.
|
||||
assert len(pos_names) + len(kwarg_names) == len(sig.parameters)
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
if len(args) < len(pos_names):
|
||||
a = pos_names[len(args)]
|
||||
raise TypeError(f"{f.__name__} missing required positional argument: {a}")
|
||||
|
||||
pos_args = args[:len(pos_names)]
|
||||
extra_kwargs = args[len(pos_names):]
|
||||
|
||||
if len(extra_kwargs) > len(kwarg_names):
|
||||
raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} "
|
||||
f" arguments but {len(args)} were given.")
|
||||
|
||||
for name, value in zip(kwarg_names, extra_kwargs):
|
||||
if name in kwargs:
|
||||
raise TypeError(f"{f.__name__} got multiple values for argument: "
|
||||
f"{name}")
|
||||
|
||||
warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
|
||||
"argument. Support for passing it positionally will be "
|
||||
"removed in an upcoming JAX release.",
|
||||
DeprecationWarning)
|
||||
kwargs[name] = value
|
||||
return f(*pos_args, **kwargs)
|
||||
|
||||
return cast(TFun, wrapped)
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
|
||||
"""Cholesky decomposition.
|
||||
|
||||
@ -136,7 +89,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
|
||||
x = symmetrize(x)
|
||||
return jnp.tril(cholesky_p.bind(x))
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
|
||||
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
|
||||
compute_right_eigenvectors: bool = True) -> list[Array]:
|
||||
"""Eigendecomposition of a general matrix.
|
||||
@ -162,7 +115,6 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
def eigh(
|
||||
x: Array,
|
||||
*,
|
||||
@ -267,7 +219,7 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
|
||||
lu, pivots, permutation = lu_p.bind(x)
|
||||
return lu, pivots, permutation
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
|
||||
def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
|
||||
"""QR decomposition.
|
||||
|
||||
@ -333,7 +285,6 @@ def svd(
|
||||
|
||||
|
||||
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
|
||||
@_warn_on_positional_kwargs
|
||||
def svd(
|
||||
x: ArrayLike,
|
||||
*,
|
||||
@ -361,7 +312,6 @@ def svd(
|
||||
return s
|
||||
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
|
||||
left_side: bool = False, lower: bool = False,
|
||||
transpose_a: bool = False, conjugate_a: bool = False,
|
||||
@ -2208,7 +2158,6 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
|
||||
# Schur Decomposition
|
||||
|
||||
|
||||
@_warn_on_positional_kwargs
|
||||
def schur(x: ArrayLike, *,
|
||||
compute_schur_vectors: bool = True,
|
||||
sort_eig_vals: bool = False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user