Refine argument validation inside jax.scipy.sparse.linalg.cg (#3630)

Now we check tree structure and leaf shapes separately. This allow us to
support pytrees that either don't define equality or that define it
inconsistently (e.g., elementwise like NumPy) with builtin data structures like
list/dict.
This commit is contained in:
Stephan Hoyer 2020-07-06 09:24:44 -07:00 committed by GitHub
parent 2a9c2d22cf
commit 36eb137dd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 7 deletions

View File

@ -18,7 +18,8 @@ import operator
import numpy as np
import jax.numpy as jnp
from jax import lax, device_put
from jax.tree_util import tree_leaves, tree_map, tree_multimap
from jax.tree_util import tree_leaves, tree_map, tree_multimap, tree_structure
from jax.util import safe_map as map
def _vdot_real_part(x, y):
@ -85,6 +86,10 @@ def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
return x_final
def _shapes(pytree):
return map(jnp.shape, tree_leaves(pytree))
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
@ -150,10 +155,15 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
if M is None:
M = _identity
shape = partial(tree_map, lambda x: x.shape)
if shape(x0) != shape(b):
if tree_structure(x0) != tree_structure(b):
raise ValueError(
f'x0 and b must have matching shape: {shape(x0)} vs {shape(b)}')
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')
if _shapes(x0) != _shapes(b):
raise ValueError(
'arrays in x0 and b must have matching shapes: '
f'{_shapes(x0)} vs {_shapes(b)}')
cg_solve = partial(
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

View File

@ -23,6 +23,7 @@ from jax import jit
import jax.numpy as jnp
from jax import lax
from jax import test_util as jtu
from jax.tree_util import register_pytree_node_class
import jax.scipy.sparse.linalg
from jax.config import config
@ -151,11 +152,31 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
def test_cg_errors(self):
A = lambda x: x
b = jnp.zeros((2, 1))
x0 = jnp.zeros((2,))
b = jnp.zeros((2,))
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching tree structure"):
jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching shape"):
jax.scipy.sparse.linalg.cg(A, b, x0)
jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
def test_cg_without_pytree_equality(self):
@register_pytree_node_class
class MinimalPytree:
def __init__(self, value):
self.value = value
def tree_flatten(self):
return [self.value], None
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
A = lambda x: MinimalPytree(2 * x.value)
b = MinimalPytree(jnp.arange(5.0))
expected = b.value / 2
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
self.assertAllClose(expected, actual.value)
if __name__ == "__main__":