mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix some busted batching rules in lax.linalg.
PiperOrigin-RevId: 726543703
This commit is contained in:
parent
7f999298ac
commit
ea4e324fe4
@ -1976,7 +1976,7 @@ def _householder_product_batching_rule(batched_args, batch_dims):
|
||||
a, taus = batched_args
|
||||
b_a, b_taus, = batch_dims
|
||||
return householder_product(batching.moveaxis(a, b_a, 0),
|
||||
batching.moveaxis(taus, b_taus, 0)), (0,)
|
||||
batching.moveaxis(taus, b_taus, 0)), 0
|
||||
|
||||
def _householder_product_lowering_rule(ctx, a, taus):
|
||||
aval_out, = ctx.avals_out
|
||||
@ -2865,7 +2865,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return hessenberg(x), 0
|
||||
return hessenberg(x), (0, 0)
|
||||
|
||||
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
|
||||
|
||||
@ -2961,7 +2961,7 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return tridiagonal(x, lower=lower), 0
|
||||
return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0, 0)
|
||||
|
||||
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
|
||||
|
||||
|
@ -1766,6 +1766,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
check_dtypes=not calc_q)
|
||||
self._CompileAndCheck(jsp_func, args_maker)
|
||||
|
||||
if len(shape) == 3:
|
||||
args = args_maker()
|
||||
self.assertAllClose(jax.vmap(jsp_func)(*args), jsp_func(*args))
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
|
||||
dtype=float_types + complex_types,
|
||||
@ -1798,6 +1802,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
|
||||
check_dtypes=False)
|
||||
|
||||
if len(shape) == 3:
|
||||
args = args_maker()
|
||||
self.assertAllClose(jax.vmap(jax_func)(*args), jax_func(*args))
|
||||
|
||||
@jtu.sample_product(
|
||||
n=[1, 4, 5, 20, 50, 100],
|
||||
dtype=float_types + complex_types,
|
||||
|
Loading…
x
Reference in New Issue
Block a user