Fix some busted batching rules in lax.linalg.

PiperOrigin-RevId: 726543703
This commit is contained in:
Dan Foreman-Mackey 2025-02-13 10:27:52 -08:00 committed by jax authors
parent 7f999298ac
commit ea4e324fe4
2 changed files with 11 additions and 3 deletions

View File

@ -1976,7 +1976,7 @@ def _householder_product_batching_rule(batched_args, batch_dims):
a, taus = batched_args
b_a, b_taus, = batch_dims
return householder_product(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,)
batching.moveaxis(taus, b_taus, 0)), 0
def _householder_product_lowering_rule(ctx, a, taus):
aval_out, = ctx.avals_out
@ -2865,7 +2865,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return hessenberg(x), 0
return hessenberg(x), (0, 0)
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
@ -2961,7 +2961,7 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return tridiagonal(x, lower=lower), 0
return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0, 0)
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

View File

@ -1766,6 +1766,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
check_dtypes=not calc_q)
self._CompileAndCheck(jsp_func, args_maker)
if len(shape) == 3:
args = args_maker()
self.assertAllClose(jax.vmap(jsp_func)(*args), jsp_func(*args))
@jtu.sample_product(
shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
dtype=float_types + complex_types,
@ -1798,6 +1802,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
check_dtypes=False)
if len(shape) == 3:
args = args_maker()
self.assertAllClose(jax.vmap(jax_func)(*args), jax_func(*args))
@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
dtype=float_types + complex_types,