mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0
This commit is contained in:
parent
226cef08bf
commit
9090dd179d
@ -26,6 +26,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
|
* {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
|
||||||
* {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
|
* {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
|
||||||
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
|
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
|
||||||
|
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
|
||||||
|
following a similar deprecation in {func}`scipy.linalg.solve`.
|
||||||
|
|
||||||
## jaxlib 0.3.15 (Unreleased)
|
## jaxlib 0.3.15 (Unreleased)
|
||||||
|
|
||||||
|
@ -205,9 +205,9 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
|
|||||||
return _qr(a, mode, pivoting)
|
return _qr(a, mode, pivoting)
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=('sym_pos', 'lower'))
|
@partial(jit, static_argnames=('assume_a', 'lower'))
|
||||||
def _solve(a, b, sym_pos, lower):
|
def _solve(a, b, assume_a, lower):
|
||||||
if not sym_pos:
|
if assume_a != 'pos':
|
||||||
return np_linalg.solve(a, b)
|
return np_linalg.solve(a, b)
|
||||||
|
|
||||||
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
||||||
@ -232,9 +232,18 @@ def _solve(a, b, sym_pos, lower):
|
|||||||
@_wraps(scipy.linalg.solve,
|
@_wraps(scipy.linalg.solve,
|
||||||
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
|
||||||
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False,
|
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False,
|
||||||
debug=False, check_finite=True):
|
debug=False, check_finite=True, assume_a='gen'):
|
||||||
|
# TODO(jakevdp) remove sym_pos argument after October 2022
|
||||||
del overwrite_a, overwrite_b, debug, check_finite
|
del overwrite_a, overwrite_b, debug, check_finite
|
||||||
return _solve(a, b, sym_pos, lower)
|
valid_assume_a = ['gen', 'sym', 'her', 'pos']
|
||||||
|
if assume_a not in valid_assume_a:
|
||||||
|
raise ValueError("Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
|
||||||
|
if sym_pos:
|
||||||
|
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
|
||||||
|
"in a future JAX release. Use assume_a='pos' instead.",
|
||||||
|
category=FutureWarning, stacklevel=2)
|
||||||
|
assume_a = 'pos'
|
||||||
|
return _solve(a, b, assume_a, lower)
|
||||||
|
|
||||||
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
||||||
def _solve_triangular(a, b, trans, lower, unit_diagonal):
|
def _solve_triangular(a, b, trans, lower, unit_diagonal):
|
||||||
|
@ -512,7 +512,7 @@ def _lstsq(a, b):
|
|||||||
# faster than jsp.linalg.lstsq
|
# faster than jsp.linalg.lstsq
|
||||||
a2 = _dot(a.T.conj(), a)
|
a2 = _dot(a.T.conj(), a)
|
||||||
b2 = _dot(a.T.conj(), b)
|
b2 = _dot(a.T.conj(), b)
|
||||||
return jsp.linalg.solve(a2, b2, sym_pos=True)
|
return jsp.linalg.solve(a2, b2, assume_a='pos')
|
||||||
|
|
||||||
|
|
||||||
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
|
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
|
||||||
|
@ -1153,31 +1153,31 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name":
|
{"testcase_name":
|
||||||
"_lhs={}_rhs={}_sym_pos={}_lower={}".format(
|
"_lhs={}_rhs={}_assume_a={}_lower={}".format(
|
||||||
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||||
jtu.format_shape_dtype_string(rhs_shape, dtype),
|
jtu.format_shape_dtype_string(rhs_shape, dtype),
|
||||||
sym_pos, lower),
|
assume_a, lower),
|
||||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
||||||
"sym_pos": sym_pos, "lower": lower}
|
"assume_a": assume_a, "lower": lower}
|
||||||
for lhs_shape, rhs_shape in [
|
for lhs_shape, rhs_shape in [
|
||||||
((1, 1), (1, 1)),
|
((1, 1), (1, 1)),
|
||||||
((4, 4), (4,)),
|
((4, 4), (4,)),
|
||||||
((8, 8), (8, 4)),
|
((8, 8), (8, 4)),
|
||||||
]
|
]
|
||||||
for sym_pos, lower in [
|
for assume_a, lower in [
|
||||||
(False, False),
|
('gen', False),
|
||||||
(True, False),
|
('pos', False),
|
||||||
(True, True),
|
('pos', True),
|
||||||
]
|
]
|
||||||
for dtype in float_types + complex_types))
|
for dtype in float_types + complex_types))
|
||||||
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower):
|
def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower):
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
|
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
|
||||||
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
|
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
|
||||||
|
|
||||||
def args_maker():
|
def args_maker():
|
||||||
a = rng(lhs_shape, dtype)
|
a = rng(lhs_shape, dtype)
|
||||||
if sym_pos:
|
if assume_a == 'pos':
|
||||||
a = np.matmul(a, np.conj(T(a)))
|
a = np.matmul(a, np.conj(T(a)))
|
||||||
a = np.tril(a) if lower else np.triu(a)
|
a = np.tril(a) if lower else np.triu(a)
|
||||||
return [a, rng(rhs_shape, dtype)]
|
return [a, rng(rhs_shape, dtype)]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user