Merge pull request #19224 from jakevdp:batched-solve

PiperOrigin-RevId: 596313955
This commit is contained in:
jax authors 2024-01-06 21:23:13 -08:00
commit 16699e4e78
2 changed files with 21 additions and 3 deletions

View File

@ -608,9 +608,13 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]:
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
if a.ndim >= 2 and b.ndim > a.ndim:
a = lax.expand_dims(a, tuple(range(b.ndim - a.ndim)))
return lax_linalg._solve(a, b)
# TODO(jakevdp): this condition matches the broadcasting behavior in numpy < 2.0.
# For the array API specification, we would check only if b.ndim == 1.
if b.ndim == 1 or a.ndim == b.ndim + 1:
signature = "(m,m),(m)->(m)"
else:
signature = "(m,m),(m,n)->(m,n)"
return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,

View File

@ -1020,6 +1020,20 @@ class NumpyLinalgTest(jtu.JaxTestCase):
tol=1e-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker)
@jtu.sample_product(
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
)
def testSolveBroadcasting(self, lhs_shape, rhs_shape):
# Batched solve can involve some ambiguities; this test checks
# that we match NumPy's convention in all cases.
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')]
# As of numpy 1.26.3, np.linalg.solve fails when this condition is not met.
if len(lhs_shape) == 2 or len(rhs_shape) > 1:
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (2, 5, 5), (100, 100), (5, 5, 5), (0, 0)],
dtype=float_types,