mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove deprecated sym_pos
argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
This commit is contained in:
parent
3ee506d09a
commit
340e655ac2
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.21
|
||||
|
||||
* Deprecations
|
||||
* The previously-deprecated `sym_pos` argument has been removed from
|
||||
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
||||
|
||||
## jaxlib 0.4.21
|
||||
|
||||
## jax 0.4.20 (Nov 2, 2023)
|
||||
|
@ -18,7 +18,6 @@ from functools import partial
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
import textwrap
|
||||
import warnings
|
||||
from typing import cast, overload, Any, Literal, Optional, Union
|
||||
|
||||
import jax
|
||||
@ -353,20 +352,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.solve,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||
def solve(a: ArrayLike, b: ArrayLike, sym_pos: bool = False, lower: bool = False,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc,
|
||||
skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
|
||||
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
|
||||
check_finite: bool = True, assume_a: str = 'gen') -> Array:
|
||||
# TODO(jakevdp) remove sym_pos argument after October 2022
|
||||
del overwrite_a, overwrite_b, debug, check_finite #unused
|
||||
valid_assume_a = ['gen', 'sym', 'her', 'pos']
|
||||
if assume_a not in valid_assume_a:
|
||||
raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
|
||||
if sym_pos:
|
||||
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
|
||||
"in a future JAX release. Use assume_a='pos' instead.",
|
||||
category=FutureWarning, stacklevel=2)
|
||||
assume_a = 'pos'
|
||||
return _solve(a, b, assume_a, lower)
|
||||
|
||||
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user