mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove deprecated sym_pos
argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
This commit is contained in:
parent
3ee506d09a
commit
340e655ac2
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.21
|
## jax 0.4.21
|
||||||
|
|
||||||
|
* Deprecations
|
||||||
|
* The previously-deprecated `sym_pos` argument has been removed from
|
||||||
|
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
||||||
|
|
||||||
## jaxlib 0.4.21
|
## jaxlib 0.4.21
|
||||||
|
|
||||||
## jax 0.4.20 (Nov 2, 2023)
|
## jax 0.4.20 (Nov 2, 2023)
|
||||||
|
@ -18,7 +18,6 @@ from functools import partial
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.linalg
|
import scipy.linalg
|
||||||
import textwrap
|
import textwrap
|
||||||
import warnings
|
|
||||||
from typing import cast, overload, Any, Literal, Optional, Union
|
from typing import cast, overload, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
@ -353,20 +352,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
|
|||||||
|
|
||||||
|
|
||||||
@_wraps(scipy.linalg.solve,
|
@_wraps(scipy.linalg.solve,
|
||||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
lax_description=_no_overwrite_and_chkfinite_doc,
|
||||||
def solve(a: ArrayLike, b: ArrayLike, sym_pos: bool = False, lower: bool = False,
|
skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||||
|
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
|
||||||
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
|
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
|
||||||
check_finite: bool = True, assume_a: str = 'gen') -> Array:
|
check_finite: bool = True, assume_a: str = 'gen') -> Array:
|
||||||
# TODO(jakevdp) remove sym_pos argument after October 2022
|
|
||||||
del overwrite_a, overwrite_b, debug, check_finite #unused
|
del overwrite_a, overwrite_b, debug, check_finite #unused
|
||||||
valid_assume_a = ['gen', 'sym', 'her', 'pos']
|
valid_assume_a = ['gen', 'sym', 'her', 'pos']
|
||||||
if assume_a not in valid_assume_a:
|
if assume_a not in valid_assume_a:
|
||||||
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
|
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
|
||||||
if sym_pos:
|
|
||||||
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
|
|
||||||
"in a future JAX release. Use assume_a='pos' instead.",
|
|
||||||
category=FutureWarning, stacklevel=2)
|
|
||||||
assume_a = 'pos'
|
|
||||||
return _solve(a, b, assume_a, lower)
|
return _solve(a, b, assume_a, lower)
|
||||||
|
|
||||||
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user