Remove deprecated sym_pos argument from jax.scipy.linalg.solve

PiperOrigin-RevId: 580940755
This commit is contained in:
Jake VanderPlas 2023-11-09 09:52:53 -08:00 committed by jax authors
parent 3ee506d09a
commit 340e655ac2
2 changed files with 7 additions and 9 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.21
* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
## jaxlib 0.4.21
## jax 0.4.20 (Nov 2, 2023)

View File

@ -18,7 +18,6 @@ from functools import partial
import numpy as np
import scipy.linalg
import textwrap
import warnings
from typing import cast, overload, Any, Literal, Optional, Union
import jax
@ -353,20 +352,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
@_wraps(scipy.linalg.solve,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
def solve(a: ArrayLike, b: ArrayLike, sym_pos: bool = False, lower: bool = False,
lax_description=_no_overwrite_and_chkfinite_doc,
skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
check_finite: bool = True, assume_a: str = 'gen') -> Array:
# TODO(jakevdp) remove sym_pos argument after October 2022
del overwrite_a, overwrite_b, debug, check_finite #unused
valid_assume_a = ['gen', 'sym', 'her', 'pos']
if assume_a not in valid_assume_a:
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
if sym_pos:
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
"in a future JAX release. Use assume_a='pos' instead.",
category=FutureWarning, stacklevel=2)
assume_a = 'pos'
return _solve(a, b, assume_a, lower)
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))