mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0
This commit is contained in:
parent
226cef08bf
commit
9090dd179d
@ -26,6 +26,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
|
||||
* {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
|
||||
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
|
||||
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
|
||||
following a similar deprecation in {func}`scipy.linalg.solve`.
|
||||
|
||||
## jaxlib 0.3.15 (Unreleased)
|
||||
|
||||
|
@ -205,9 +205,9 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
|
||||
return _qr(a, mode, pivoting)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('sym_pos', 'lower'))
|
||||
def _solve(a, b, sym_pos, lower):
|
||||
if not sym_pos:
|
||||
@partial(jit, static_argnames=('assume_a', 'lower'))
|
||||
def _solve(a, b, assume_a, lower):
|
||||
if assume_a != 'pos':
|
||||
return np_linalg.solve(a, b)
|
||||
|
||||
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
||||
@ -232,9 +232,18 @@ def _solve(a, b, sym_pos, lower):
|
||||
@_wraps(scipy.linalg.solve,
|
||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False,
|
||||
debug=False, check_finite=True):
|
||||
debug=False, check_finite=True, assume_a='gen'):
|
||||
# TODO(jakevdp) remove sym_pos argument after October 2022
|
||||
del overwrite_a, overwrite_b, debug, check_finite
|
||||
return _solve(a, b, sym_pos, lower)
|
||||
valid_assume_a = ['gen', 'sym', 'her', 'pos']
|
||||
if assume_a not in valid_assume_a:
|
||||
raise ValueError("Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
|
||||
if sym_pos:
|
||||
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
|
||||
"in a future JAX release. Use assume_a='pos' instead.",
|
||||
category=FutureWarning, stacklevel=2)
|
||||
assume_a = 'pos'
|
||||
return _solve(a, b, assume_a, lower)
|
||||
|
||||
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
||||
def _solve_triangular(a, b, trans, lower, unit_diagonal):
|
||||
|
@ -512,7 +512,7 @@ def _lstsq(a, b):
|
||||
# faster than jsp.linalg.lstsq
|
||||
a2 = _dot(a.T.conj(), a)
|
||||
b2 = _dot(a.T.conj(), b)
|
||||
return jsp.linalg.solve(a2, b2, sym_pos=True)
|
||||
return jsp.linalg.solve(a2, b2, assume_a='pos')
|
||||
|
||||
|
||||
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
|
||||
|
@ -1153,31 +1153,31 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_lhs={}_rhs={}_sym_pos={}_lower={}".format(
|
||||
"_lhs={}_rhs={}_assume_a={}_lower={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||
jtu.format_shape_dtype_string(rhs_shape, dtype),
|
||||
sym_pos, lower),
|
||||
assume_a, lower),
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
||||
"sym_pos": sym_pos, "lower": lower}
|
||||
"assume_a": assume_a, "lower": lower}
|
||||
for lhs_shape, rhs_shape in [
|
||||
((1, 1), (1, 1)),
|
||||
((4, 4), (4,)),
|
||||
((8, 8), (8, 4)),
|
||||
]
|
||||
for sym_pos, lower in [
|
||||
(False, False),
|
||||
(True, False),
|
||||
(True, True),
|
||||
for assume_a, lower in [
|
||||
('gen', False),
|
||||
('pos', False),
|
||||
('pos', True),
|
||||
]
|
||||
for dtype in float_types + complex_types))
|
||||
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower):
|
||||
def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
|
||||
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
|
||||
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
|
||||
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
|
||||
|
||||
def args_maker():
|
||||
a = rng(lhs_shape, dtype)
|
||||
if sym_pos:
|
||||
if assume_a == 'pos':
|
||||
a = np.matmul(a, np.conj(T(a)))
|
||||
a = np.tril(a) if lower else np.triu(a)
|
||||
return [a, rng(rhs_shape, dtype)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user