diff --git a/CHANGELOG.md b/CHANGELOG.md index 246bfab5d..c623b212c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.21 +* Deprecations + * The previously-deprecated `sym_pos` argument has been removed from + {func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead. + ## jaxlib 0.4.21 ## jax 0.4.20 (Nov 2, 2023) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 6f6f96473..a7f93f9fc 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -18,7 +18,6 @@ from functools import partial import numpy as np import scipy.linalg import textwrap -import warnings from typing import cast, overload, Any, Literal, Optional, Union import jax @@ -353,20 +352,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: @_wraps(scipy.linalg.solve, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) -def solve(a: ArrayLike, b: ArrayLike, sym_pos: bool = False, lower: bool = False, + lax_description=_no_overwrite_and_chkfinite_doc, + skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) +def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False, check_finite: bool = True, assume_a: str = 'gen') -> Array: - # TODO(jakevdp) remove sym_pos argument after October 2022 del overwrite_a, overwrite_b, debug, check_finite #unused valid_assume_a = ['gen', 'sym', 'her', 'pos'] if assume_a not in valid_assume_a: raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}") - if sym_pos: - warnings.warn("The sym_pos argument to solve() is deprecated and will be removed " - "in a future JAX release. Use assume_a='pos' instead.", - category=FutureWarning, stacklevel=2) - assume_a = 'pos' return _solve(a, b, assume_a, lower) @partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))