jnp.linalg.solve: finalize deprecation of batched 1D solves

This commit is contained in:
Jake VanderPlas 2025-01-08 08:57:56 -08:00
parent 1cc07dd392
commit 051abafd6d
3 changed files with 19 additions and 25 deletions

View File

@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).
* {func}`jax.numpy.linalg.solve` no longer supports batched 1D arguments
on the right hand side. To recover the previous behavior in these cases,
use `solve(a, b[..., None]).squeeze(-1)`.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,

View File

@ -18,7 +18,6 @@ from collections.abc import Sequence
from functools import partial
import itertools
import math
import warnings
import numpy as np
import operator
@ -1336,17 +1335,19 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
if b.ndim == 1:
signature = "(m,m),(m)->(m)"
elif a.ndim == b.ndim + 1:
# Deprecation warning added 2024-02-06
warnings.warn("jnp.linalg.solve: batched 1D solves with b.ndim > 1 are deprecated, "
"and in the future will be treated as a batched 2D solve. "
"Use solve(a, b[..., None])[..., 0] to avoid this warning.",
category=FutureWarning)
signature = "(m,m),(m)->(m)"
else:
signature = "(m,m),(m,n)->(m,n)"
if a.ndim < 2:
raise ValueError(
f"left hand array must be at least two dimensional; got {a.shape=}")
# Check for invalid inputs that previously would have led to a batched 1D solve:
if (b.ndim > 1 and a.ndim == b.ndim + 1 and
a.shape[-1] == b.shape[-1] and a.shape[-1] != b.shape[-2]):
raise ValueError(
f"Invalid shapes for solve: {a.shape}, {b.shape}. Prior to JAX v0.5.0,"
" this would have been treated as a batched 1-dimensional solve."
" To recover this behavior, use solve(a, b[..., None]).squeeze(-1).")
signature = "(m,m),(m)->(m)" if b.ndim == 1 else "(m,m),(m,n)->(m,n)"
return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)

View File

@ -1107,7 +1107,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
],
dtype=float_types + complex_types,
)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testSolve(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
@ -1121,22 +1120,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testSolveBroadcasting(self, lhs_shape, rhs_shape):
# Batched solve can involve some ambiguities; this test checks
# that we match NumPy's convention in all cases.
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')]
if jtu.numpy_version() >= (2, 0, 0):
# TODO(jakevdp) remove this condition after solve broadcast deprecation is finalized.
if len(rhs_shape) == 1 or (len(lhs_shape) != len(rhs_shape) + 1):
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
else: # numpy 1.X
# As of numpy 1.26.3, np.linalg.solve fails when this condition is not met.
if len(lhs_shape) == 2 or len(rhs_shape) > 1:
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
if jtu.numpy_version() >= (2, 0, 0): # NumPy 2.0 semantics
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker)
@jtu.sample_product(
@ -1312,11 +1303,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(xc, grad_test_jc(xc))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
def testIssue1151(self):
rng = self.rng()
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
b = jnp.array(rng.randn(100, 3, 1), dtype=jnp.float32)
x = jnp.linalg.solve(A, b)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)