Remove deprecated sym_pos argument from jax.scipy.linalg.solve

PiperOrigin-RevId: 580940755
This commit is contained in:
Jake VanderPlas 2023-11-09 09:52:53 -08:00 committed by jax authors
parent 3ee506d09a
commit 340e655ac2
2 changed files with 7 additions and 9 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.21 ## jax 0.4.21
* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
## jaxlib 0.4.21 ## jaxlib 0.4.21
## jax 0.4.20 (Nov 2, 2023) ## jax 0.4.20 (Nov 2, 2023)

View File

@ -18,7 +18,6 @@ from functools import partial
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
import textwrap import textwrap
import warnings
from typing import cast, overload, Any, Literal, Optional, Union from typing import cast, overload, Any, Literal, Optional, Union
import jax import jax
@ -353,20 +352,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
@_wraps(scipy.linalg.solve, @_wraps(scipy.linalg.solve,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) lax_description=_no_overwrite_and_chkfinite_doc,
def solve(a: ArrayLike, b: ArrayLike, sym_pos: bool = False, lower: bool = False, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
check_finite: bool = True, assume_a: str = 'gen') -> Array: check_finite: bool = True, assume_a: str = 'gen') -> Array:
# TODO(jakevdp) remove sym_pos argument after October 2022
del overwrite_a, overwrite_b, debug, check_finite #unused del overwrite_a, overwrite_b, debug, check_finite #unused
valid_assume_a = ['gen', 'sym', 'her', 'pos'] valid_assume_a = ['gen', 'sym', 'her', 'pos']
if assume_a not in valid_assume_a: if assume_a not in valid_assume_a:
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}") raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
if sym_pos:
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
"in a future JAX release. Use assume_a='pos' instead.",
category=FutureWarning, stacklevel=2)
assume_a = 'pos'
return _solve(a, b, assume_a, lower) return _solve(a, b, assume_a, lower)
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal')) @partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))