rocm_jax/tests/lobpcg_test.py
Peter Hawkins 6b76937c53 Remove matplotlib from the test requirements.
In the Windows CI, we seem to be hitting the following error:

```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
    from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
    import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
    from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
    from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
    from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
    from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
    fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
    json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
    with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
    return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
    with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
    return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
    return self._accessor.open(self, flags, mode)
E   PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```

The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.
2023-06-14 09:02:49 -04:00
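The fix keeps the plotting feature but defers the matplotlib import to the one helper that needs it, so test collection never touches the font cache. A minimal sketch of the pattern (`maybe_plot` and its arguments are illustrative; the real code is the `_debug_plots` helper in the file below):

```python
def maybe_plot(history, path):
  """Optional debug plotting; matplotlib is needed only if this runs."""
  # Deferred import: collecting the test module no longer imports
  # matplotlib, so its font-cache lock is never touched in CI unless
  # the debug feature is actually exercised.
  from matplotlib import pyplot as plt

  plt.switch_backend('Agg')  # headless backend, safe on CI runners
  fig, ax = plt.subplots()
  ax.semilogy(history)
  fig.savefig(path, bbox_inches='tight')
  plt.close(fig)
```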

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lobpcg routine.
If LOBPCG_DEBUG_PLOT_DIR is set, exports debug visuals to that directory.
Requires matplotlib.
"""
import functools
import os
import re

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import scipy.linalg as sla
import scipy.sparse as sps

import jax
from jax import config
from jax._src import test_util as jtu
from jax.experimental.sparse import linalg, bcoo
import jax.numpy as jnp

def _clean_matrix_name(name):
  return re.sub('[^0-9a-zA-Z]+', '_', name)


def _make_concrete_cases(f64):
  dtype = np.float64 if f64 else np.float32
  example_names = list(_concrete_generators(dtype))
  cases = []
  for name in example_names:
    n, k, m, tol = 100, 10, 20, None
    if name == 'ring laplacian':
      m *= 3
    if name.startswith('linear'):
      m *= 2
    if f64:
      m *= 2
    if name.startswith('cluster') and not f64:
      tol = 2e-6
    clean_matrix_name = _clean_matrix_name(name)
    case = {
        'matrix_name': name,
        'n': n,
        'k': k,
        'm': m,
        'tol': tol,
        'testcase_name': f'{clean_matrix_name}_n{n}',
    }
    cases.append(case)

  # Parameterized test case names must be unique.
  assert len({c['testcase_name'] for c in cases}) == len(cases)
  return cases


def _make_callable_cases(f64):
  dtype = np.float64 if f64 else np.float32
  example_names = list(_callable_generators(dtype))
  return [{'testcase_name': _clean_matrix_name(n), 'matrix_name': n}
          for n in example_names]


def _make_ring(n):
  # From the scipy lobpcg tests.
  col = np.zeros(n)
  col[1] = 1
  A = sla.toeplitz(col)
  D = np.diag(A.sum(axis=1))
  L = D - A

  # Compute the full eigendecomposition using tricks, e.g.
  # http://www.cs.yale.edu/homes/spielman/561/2009/lect02-09.pdf
  tmp = np.pi * np.arange(n) / n
  analytic_w = 2 * (1 - np.cos(tmp))
  analytic_w.sort()
  analytic_w = analytic_w[::-1]
  return L, analytic_w


def _make_diag(diag):
  diag.sort()
  diag = diag[::-1]
  return np.diag(diag), diag


def _make_cluster(to_cluster, n):
  return _make_diag(
      np.array([1000] * to_cluster + [1] * (n - to_cluster)))


def _concrete_generators(dtype):
  d = {
      'id': lambda n, _k: _make_diag(np.ones(n)),
      'linear cond=1k': lambda n, _k: _make_diag(np.linspace(1, 1000, n)),
      'linear cond=100k':
          lambda n, _k: _make_diag(np.linspace(1, 100 * 1000, n)),
      'geom cond=1k': lambda n, _k: _make_diag(np.logspace(0, 3, n)),
      'geom cond=100k': lambda n, _k: _make_diag(np.logspace(0, 5, n)),
      'ring laplacian': lambda n, _k: _make_ring(n),
      'cluster(k/2)': lambda n, k: _make_cluster(k // 2, n),
      'cluster(k-1)': lambda n, k: _make_cluster(k - 1, n),
      'cluster(k)': lambda n, k: _make_cluster(k, n),
  }

  def cast_fn(fn):
    def casted_fn(n, k):
      result = fn(n, k)
      cast = functools.partial(np.array, dtype=dtype)
      return tuple(map(cast, result))
    return casted_fn

  return {k: cast_fn(v) for k, v in d.items()}

def _make_id_fn(n):
  return lambda x: x, np.ones(n), 5


def _make_diag_fn(diagonal, m):
  return lambda x: diagonal.astype(x.dtype) * x, diagonal, m


def _make_ring_fn(n, m):
  _, eigs = _make_ring(n)

  def ring_action(x):
    degree = 2 * x
    lnbr = jnp.roll(x, 1)
    rnbr = jnp.roll(x, -1)
    return degree - lnbr - rnbr

  return ring_action, eigs, m


def _make_randn_fn(n, k, m):
  rng = np.random.default_rng(1234)
  tall_skinny = rng.standard_normal((n, k))

  def randn_action(x):
    ts = jnp.array(tall_skinny, dtype=x.dtype)
    p = jax.lax.Precision.HIGHEST
    return ts.dot(ts.T.dot(x, precision=p), precision=p)

  _, s, _ = np.linalg.svd(tall_skinny, full_matrices=False)
  return randn_action, s ** 2, m


def _make_sparse_fn(n, fill):
  rng = np.random.default_rng(1234)
  slots = n ** 2
  filled = max(int(slots * fill), 1)
  pos = rng.choice(slots, size=filled, replace=False)
  posx, posy = divmod(pos, n)
  data = rng.standard_normal(len(pos))
  coo = sps.coo_matrix((data, (posx, posy)), shape=(n, n))

  def sparse_action(x):
    coo_typed = coo.astype(np.dtype(x.dtype))
    sps_mat = bcoo.BCOO.from_scipy_sparse(coo_typed)
    dn = (((1,), (0,)), ((), ()))  # Good old fashioned matmul.
    x = bcoo.bcoo_dot_general(sps_mat, x, dimension_numbers=dn)
    sps_mat_T = sps_mat.transpose()
    return bcoo.bcoo_dot_general(sps_mat_T, x, dimension_numbers=dn)

  dense = coo.todense()
  _, s, _ = np.linalg.svd(dense, full_matrices=False)
  return sparse_action, s ** 2, 20


def _callable_generators(dtype):
  n = 100
  topk = 10
  d = {
      'id': _make_id_fn(n),
      'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
      'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
      'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
      'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
      'ring laplacian': _make_ring_fn(n, 40),
      'randn': _make_randn_fn(n, topk, 40),
      'sparse 1%': _make_sparse_fn(n, 0.01),
      'sparse 10%': _make_sparse_fn(n, 0.10),
  }

  ret = {}
  for k, (vec_mul_fn, eigs, m) in d.items():
    if jtu.num_float_bits(dtype) > 32:
      m *= 3
    eigs.sort()
    # Note we must lift the vector multiply into matmul.
    fn = jax.vmap(vec_mul_fn, in_axes=1, out_axes=1)
    ret[k] = (fn, eigs[::-1][:topk].astype(dtype), n, m)
  return ret

@jtu.with_config(
    jax_enable_checks=True,
    jax_debug_nans=True,
    jax_numpy_rank_promotion='raise',
    jax_traceback_filtering='off')
class LobpcgTest(jtu.JaxTestCase):

  def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
    A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    A, X = (jnp.array(i, dtype=dtype) for i in (A, X))
    theta, U, i = linalg.lobpcg_standard(A, X, m, tol)

    self.assertDtypesMatch(theta, A)
    self.assertDtypesMatch(U, A)
    self.assertLess(
        i, m, msg=f'expected early convergence iters {int(i)} < max {m}')

    # Eigenvalues should come back sorted in descending order.
    issorted = theta[:-1] >= theta[1:]
    all_true = np.ones_like(issorted).astype(bool)
    self.assertArraysEqual(issorted, all_true)

    k = X.shape[1]
    relerr = np.abs(theta - eigs[:k]) / eigs[:k]
    eps = float(jnp.finfo(dtype).eps) if tol is None else tol
    # There's no real guarantee we can be within x% of the true eigenvalue.
    # However, for these simple unit test examples this should be met.
    relerr_tol = float(np.sqrt(eps)) * 10
    for i in range(k):
      # The self-consistency property should be ensured: each returned
      # pair should satisfy A u ~= theta * u up to the adjusted tolerance.
      u = np.asarray(U[:, i], dtype=A.dtype)
      t = float(theta[i])
      Au = A.dot(u)
      resid = Au - t * u
      resid_norm = np.linalg.norm(resid)
      vector_norm = np.linalg.norm(Au)
      # Residual norm scaled by n * (theta + ||Au||) * 10, so that a
      # converged eigenpair falls below eps (or the requested tolerance).
      adjusted_error = resid_norm / n / (t + vector_norm) / 10
      self.assertLessEqual(
          adjusted_error,
          eps,
          msg=f'convergence criterion for eigenvalue {i} not satisfied, '
          f'floating point error {adjusted_error} not <= {eps}')
      self.assertLessEqual(
          relerr[i],
          relerr_tol,
          msg=f'expected relative error within {relerr_tol}, '
          f'was {float(relerr[i])} for eigenvalue {i} '
          f'(actual {float(theta[i])}, expected {float(eigs[i])})')

  def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
    del tol
    A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    _theta, _U, _i, info = linalg._lobpcg_standard_matrix(
        A, X, m, tol=0, debug=True)

    self.assertArraysEqual(info['X zeros'], jnp.zeros_like(info['X zeros']))

    # To check for any divergence, make sure that the last 20% of
    # steps have lower worst-case relerr than the first 20% of steps,
    # at least up to an order of magnitude.
    #
    # This is non-trivial, as many implementations have catastrophic
    # cancellations at convergence for residual terms, and rely on a
    # brittle locking tolerance to avoid divergence.
    eigs = eigs[:k]
    relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
    few_steps = max(m // 5, 1)
    self.assertLess(
        relerrs[-few_steps:].max(axis=1).mean(),
        10 * relerrs[:few_steps].max(axis=1).mean())

    self._possibly_plot(A, eigs, X, m, matrix_name)

  def _possibly_plot(self, A, eigs, X, m, matrix_name):
    if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
      return

    if isinstance(A, (np.ndarray, jax.Array)):
      lobpcg = linalg._lobpcg_standard_matrix
    else:
      lobpcg = linalg._lobpcg_standard_callable
    _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
    plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
    assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots'
    self._debug_plots(X, eigs, info, matrix_name, plot_dir)

  def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir):
    # We import matplotlib lazily because (a) it's faster this way, and
    # (b) concurrent imports of matplotlib appear to trigger some sort of
    # collision on the matplotlib cache lock on Windows.
    from matplotlib import pyplot as plt

    os.makedirs(lobpcg_debug_plot_dir, exist_ok=True)
    clean_matrix_name = _clean_matrix_name(matrix_name)
    n, k = X.shape
    dt = 'f32' if X.dtype == np.float32 else 'f64'
    figpath = os.path.join(
        lobpcg_debug_plot_dir,
        f'{clean_matrix_name}_n{n}_k{k}_{dt}.png')

    plt.switch_backend('Agg')
    fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4))
    fig.suptitle(fr'{matrix_name} ${n=},{k=}$, {dt}')
    line_styles = [':', '--', '-.', '-']

    for key, ls in zip(['X orth', 'P orth', 'P.X'], line_styles):
      ax0.semilogy(info[key], ls=ls, label=key)
    ax0.set_title('basis average orthogonality')
    ax0.legend()

    relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
    keys = ['max', 'p50', 'min']
    fns = [np.max, np.median, np.min]
    for key, fn, ls in zip(keys, fns, line_styles):
      ax1.semilogy(fn(relerrs, axis=1), ls=ls, label=key)
    ax1.set_title('eigval relerr')
    ax1.legend()

    for key, ls in zip(['basis rank', 'converged', 'P zeros'], line_styles):
      ax2.plot(info[key], ls=ls, label=key)
    ax2.set_title('basis dimension counts')
    ax2.legend()

    prefix = 'adjusted residual'
    for key, ls in zip(keys, line_styles):
      ax3.semilogy(info[prefix + ' ' + key], ls=ls, label=key)
    ax3.axhline(np.finfo(X.dtype).eps, label='eps', c='k')
    ax3.legend()
    ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}')

    fig.savefig(figpath, bbox_inches='tight')
    plt.close(fig)

  def checkApproxEigs(self, example_name, dtype):
    fn, eigs, n, m = _callable_generators(dtype)[example_name]
    k = len(eigs)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    theta, U, iters = linalg.lobpcg_standard(fn, X, m, tol=0.0)

    # Since the tolerance is zero, all m iterations should be used.
    self.assertEqual(iters, m)

    # Evaluate the results in f64.
    as64 = functools.partial(np.array, dtype=np.float64)
    theta, eigs, U = (as64(x) for x in (theta, eigs, U))

    relerr = np.abs(theta - eigs) / eigs
    UTU = U.T.dot(U)
    tol = np.sqrt(jnp.finfo(dtype).eps) * 100
    if example_name == 'ring laplacian':
      tol = 1e-2

    for i in range(k):
      self.assertLessEqual(
          relerr[i], tol,
          msg=f'eigenvalue {i} (actual {theta[i]} expected {eigs[i]})')
      # Recovered eigenvectors should be orthonormal: U^T U ~= I.
      self.assertAllClose(UTU[i, i], 1.0, rtol=tol)
      UTU[i, i] = 0
      self.assertArraysAllClose(UTU[i], np.zeros_like(UTU[i]), atol=tol)

    self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name)


class F32LobpcgTest(LobpcgTest):

  def testLobpcgValidatesArguments(self):
    A, _ = _concrete_generators(np.float32)['id'](100, 10)
    X = self.rng().standard_normal(size=(100, 10)).astype(np.float32)

    with self.assertRaisesRegex(ValueError, 'search dim > 0'):
      linalg.lobpcg_standard(A, X[:, :0])

    with self.assertRaisesRegex(ValueError, 'A, X must have same dtypes'):
      linalg.lobpcg_standard(
          lambda x: jnp.array(A).dot(x).astype(jnp.float16), X)

    with self.assertRaisesRegex(ValueError, r'A must be \(100, 100\)'):
      linalg.lobpcg_standard(A[:60, :], X)

    with self.assertRaisesRegex(ValueError, r'search dim \* 5 < matrix dim'):
      linalg.lobpcg_standard(A[:50, :50], X[:50])

  @parameterized.named_parameters(_make_concrete_cases(f64=False))
  @jtu.skip_on_devices("gpu")
  def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol):
    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)

  @parameterized.named_parameters(_make_concrete_cases(f64=False))
  def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)

  @parameterized.named_parameters(_make_callable_cases(f64=False))
  def testCallableMatricesF32(self, matrix_name):
    self.checkApproxEigs(matrix_name, jnp.float32)


@jtu.with_config(jax_enable_x64=True)
class F64LobpcgTest(LobpcgTest):

  @parameterized.named_parameters(_make_concrete_cases(f64=True))
  @jtu.skip_on_devices("tpu", "iree", "gpu")
  def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol):
    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64)

  @parameterized.named_parameters(_make_concrete_cases(f64=True))
  @jtu.skip_on_devices("tpu", "iree", "gpu")
  def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol):
    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64)

  @parameterized.named_parameters(_make_callable_cases(f64=True))
  @jtu.skip_on_devices("tpu", "iree", "gpu")
  def testCallableMatricesF64(self, matrix_name):
    self.checkApproxEigs(matrix_name, jnp.float64)


if __name__ == '__main__':
  config.parse_flags_with_absl()
  absltest.main(testLoader=jtu.JaxTestLoader())
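
For reference, a minimal standalone sketch of the API these tests exercise, based only on the call shape seen above (`linalg.lobpcg_standard(A, X, m, tol)` returning `(theta, U, iters)`); the diagonal example matrix is illustrative, not taken from the test suite:

```python
import numpy as np

import jax.numpy as jnp
from jax.experimental.sparse import linalg

n, k = 100, 10
# Diagonal SPD matrix whose top eigenvalues are known exactly.
A = jnp.array(np.diag(np.linspace(1.0, 1000.0, n)), dtype=jnp.float32)
X = jnp.array(np.random.default_rng(0).standard_normal((n, k)),
              dtype=jnp.float32)

# theta: top-k eigenvalue estimates (descending); U: eigenvectors; iters used.
theta, U, iters = linalg.lobpcg_standard(A, X, 100)
print(theta[:3])  # approaches 1000.0, ~989.9, ~979.8 for this A
```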