mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Refine argument validation inside jax.scipy.sparse.linalg.cg (#3630)
Now we check tree structure and leaf shapes separately. This allow us to support pytrees that either don't define equality or that define it inconsistently (e.g., elementwise like NumPy) with builtin data structures like list/dict.
This commit is contained in:
parent
2a9c2d22cf
commit
36eb137dd3
@ -18,7 +18,8 @@ import operator
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax import lax, device_put
|
||||
from jax.tree_util import tree_leaves, tree_map, tree_multimap
|
||||
from jax.tree_util import tree_leaves, tree_map, tree_multimap, tree_structure
|
||||
from jax.util import safe_map as map
|
||||
|
||||
|
||||
def _vdot_real_part(x, y):
|
||||
@ -85,6 +86,10 @@ def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
||||
return x_final
|
||||
|
||||
|
||||
def _shapes(pytree):
|
||||
return map(jnp.shape, tree_leaves(pytree))
|
||||
|
||||
|
||||
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
|
||||
|
||||
@ -150,10 +155,15 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
if M is None:
|
||||
M = _identity
|
||||
|
||||
shape = partial(tree_map, lambda x: x.shape)
|
||||
if shape(x0) != shape(b):
|
||||
if tree_structure(x0) != tree_structure(b):
|
||||
raise ValueError(
|
||||
f'x0 and b must have matching shape: {shape(x0)} vs {shape(b)}')
|
||||
'x0 and b must have matching tree structure: '
|
||||
f'{tree_structure(x0)} vs {tree_structure(b)}')
|
||||
|
||||
if _shapes(x0) != _shapes(b):
|
||||
raise ValueError(
|
||||
'arrays in x0 and b must have matching shapes: '
|
||||
f'{_shapes(x0)} vs {_shapes(b)}')
|
||||
|
||||
cg_solve = partial(
|
||||
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
|
||||
|
@ -23,6 +23,7 @@ from jax import jit
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
import jax.scipy.sparse.linalg
|
||||
|
||||
from jax.config import config
|
||||
@ -151,11 +152,31 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
def test_cg_errors(self):
|
||||
A = lambda x: x
|
||||
b = jnp.zeros((2, 1))
|
||||
x0 = jnp.zeros((2,))
|
||||
b = jnp.zeros((2,))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "x0 and b must have matching tree structure"):
|
||||
jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "x0 and b must have matching shape"):
|
||||
jax.scipy.sparse.linalg.cg(A, b, x0)
|
||||
jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
|
||||
|
||||
def test_cg_without_pytree_equality(self):
|
||||
|
||||
@register_pytree_node_class
|
||||
class MinimalPytree:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
def tree_flatten(self):
|
||||
return [self.value], None
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
A = lambda x: MinimalPytree(2 * x.value)
|
||||
b = MinimalPytree(jnp.arange(5.0))
|
||||
expected = b.value / 2
|
||||
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
|
||||
self.assertAllClose(expected, actual.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user