mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19224 from jakevdp:batched-solve
PiperOrigin-RevId: 596313955
This commit is contained in:
commit
16699e4e78
@ -608,9 +608,13 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]:
|
||||
def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
check_arraylike("jnp.linalg.solve", a, b)
|
||||
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
||||
if a.ndim >= 2 and b.ndim > a.ndim:
|
||||
a = lax.expand_dims(a, tuple(range(b.ndim - a.ndim)))
|
||||
return lax_linalg._solve(a, b)
|
||||
# TODO(jakevdp): this condition matches the broadcasting behavior in numpy < 2.0.
|
||||
# For the array API specification, we would check only if b.ndim == 1.
|
||||
if b.ndim == 1 or a.ndim == b.ndim + 1:
|
||||
signature = "(m,m),(m)->(m)"
|
||||
else:
|
||||
signature = "(m,m),(m,n)->(m,n)"
|
||||
return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)
|
||||
|
||||
|
||||
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
|
||||
|
@ -1020,6 +1020,20 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.solve, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
|
||||
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
|
||||
)
|
||||
def testSolveBroadcasting(self, lhs_shape, rhs_shape):
|
||||
# Batched solve can involve some ambiguities; this test checks
|
||||
# that we match NumPy's convention in all cases.
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')]
|
||||
# As of numpy 1.26.3, np.linalg.solve fails when this condition is not met.
|
||||
if len(lhs_shape) == 2 or len(rhs_shape) > 1:
|
||||
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3)
|
||||
self._CompileAndCheck(jnp.linalg.solve, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (2, 5, 5), (100, 100), (5, 5, 5), (0, 0)],
|
||||
dtype=float_types,
|
||||
|
Loading…
x
Reference in New Issue
Block a user