jnp.linalg.solve: finalize deprecation of batched 1D solves

This commit is contained in:
Jake VanderPlas 2025-01-08 08:57:56 -08:00
parent 1cc07dd392
commit 051abafd6d
3 changed files with 19 additions and 25 deletions

View File

@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in `optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`). the case of many arguments ({jax-issue}`#25214`).
* {func}`jax.numpy.linalg.solve` no longer supports batched 1D arguments
on the right hand side. To recover the previous behavior in these cases,
use `solve(a, b[..., None]).squeeze(-1)`.
* New Features * New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`, * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,

View File

@ -18,7 +18,6 @@ from collections.abc import Sequence
from functools import partial from functools import partial
import itertools import itertools
import math import math
import warnings
import numpy as np import numpy as np
import operator import operator
@ -1336,17 +1335,19 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b) check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
if b.ndim == 1: if a.ndim < 2:
signature = "(m,m),(m)->(m)" raise ValueError(
elif a.ndim == b.ndim + 1: f"left hand array must be at least two dimensional; got {a.shape=}")
# Deprecation warning added 2024-02-06
warnings.warn("jnp.linalg.solve: batched 1D solves with b.ndim > 1 are deprecated, " # Check for invalid inputs that previously would have led to a batched 1D solve:
"and in the future will be treated as a batched 2D solve. " if (b.ndim > 1 and a.ndim == b.ndim + 1 and
"Use solve(a, b[..., None])[..., 0] to avoid this warning.", a.shape[-1] == b.shape[-1] and a.shape[-1] != b.shape[-2]):
category=FutureWarning) raise ValueError(
signature = "(m,m),(m)->(m)" f"Invalid shapes for solve: {a.shape}, {b.shape}. Prior to JAX v0.5.0,"
else: " this would have been treated as a batched 1-dimensional solve."
signature = "(m,m),(m,n)->(m,n)" " To recover this behavior, use solve(a, b[..., None]).squeeze(-1).")
signature = "(m,m),(m)->(m)" if b.ndim == 1 else "(m,m),(m,n)->(m,n)"
return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b) return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)

View File

@ -1107,7 +1107,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
], ],
dtype=float_types + complex_types, dtype=float_types + complex_types,
) )
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testSolve(self, lhs_shape, rhs_shape, dtype): def testSolve(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -1121,22 +1120,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)], lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)] rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
) )
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testSolveBroadcasting(self, lhs_shape, rhs_shape): def testSolveBroadcasting(self, lhs_shape, rhs_shape):
# Batched solve can involve some ambiguities; this test checks # Batched solve can involve some ambiguities; this test checks
# that we match NumPy's convention in all cases. # that we match NumPy's convention in all cases.
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')] args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')]
if jtu.numpy_version() >= (2, 0, 0): # NumPy 2.0 semantics
if jtu.numpy_version() >= (2, 0, 0): self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
# TODO(jakevdp) remove this condition after solve broadcast deprecation is finalized.
if len(rhs_shape) == 1 or (len(lhs_shape) != len(rhs_shape) + 1):
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
else: # numpy 1.X
# As of numpy 1.26.3, np.linalg.solve fails when this condition is not met.
if len(lhs_shape) == 2 or len(rhs_shape) > 1:
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker) self._CompileAndCheck(jnp.linalg.solve, args_maker)
@jtu.sample_product( @jtu.sample_product(
@ -1312,11 +1303,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(xc, grad_test_jc(xc)) self.assertAllClose(xc, grad_test_jc(xc))
@jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=FutureWarning, message="jnp.linalg.solve: batched")
def testIssue1151(self): def testIssue1151(self):
rng = self.rng() rng = self.rng()
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32) A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
b = jnp.array(rng.randn(100, 3), dtype=jnp.float32) b = jnp.array(rng.randn(100, 3, 1), dtype=jnp.float32)
x = jnp.linalg.solve(A, b) x = jnp.linalg.solve(A, b)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2) self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)