# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lobpcg routine.
If LOBPCG_DEBUG_PLOT_DIR is set, exports debug visuals to that directory.
Requires matplotlib.
"""
import functools
import os
import re
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.linalg as sla
import scipy.sparse as sps
import jax
from jax._src import test_util as jtu
from jax.experimental.sparse import linalg, bcoo
import jax.numpy as jnp
def _clean_matrix_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name)
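# Builds the parameterized cases for the dense-matrix tests below. Each case
# fixes a 100x100 example matrix, asks for the top k=10 eigenpairs, and sets
# an iteration budget m that is enlarged for the harder spectra (ring
# Laplacian, large linear condition numbers, and f64, which targets a tighter
# tolerance).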
def _make_concrete_cases(f64):
dtype = np.float64 if f64 else np.float32
example_names = list(_concrete_generators(dtype))
cases = []
for name in example_names:
n, k, m, tol = 100, 10, 20, None
if name == 'ring laplacian':
m *= 3
if name.startswith('linear'):
m *= 2
if f64:
m *= 2
if name.startswith('cluster') and not f64:
tol = 2e-6
clean_matrix_name = _clean_matrix_name(name)
case = {
'matrix_name': name,
'n': n,
'k': k,
'm': m,
'tol': tol,
'testcase_name': f'{clean_matrix_name}_n{n}'
}
cases.append(case)
assert len({c['testcase_name'] for c in cases}) == len(cases)
return cases
def _make_callable_cases(f64):
dtype = np.float64 if f64 else np.float32
example_names = list(_callable_generators(dtype))
return [{'testcase_name': _clean_matrix_name(n), 'matrix_name': n}
for n in example_names]
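# The Laplacian constructed below (taken from the scipy lobpcg tests) has a
# known closed-form spectrum: eigenvalues 2 * (1 - cos(pi * k / n)) for
# k = 0, ..., n - 1, which the tests use as the analytic ground truth.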
def _make_ring(n):
# from lobpcg scipy tests
col = np.zeros(n)
col[1] = 1
A = sla.toeplitz(col)
D = np.diag(A.sum(axis=1))
L = D - A
# Compute the full eigendecomposition using tricks, e.g.
# http://www.cs.yale.edu/homes/spielman/561/2009/lect02-09.pdf
tmp = np.pi * np.arange(n) / n
analytic_w = 2 * (1 - np.cos(tmp))
analytic_w.sort()
analytic_w = analytic_w[::-1]
return L, analytic_w
def _make_diag(diag):
diag.sort()
diag = diag[::-1]
return np.diag(diag), diag
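# Clustered spectrum: `to_cluster` eigenvalues at 1000 and the remaining
# n - to_cluster at 1. Tightly clustered eigenvalues are a classically hard
# case for block eigensolvers, hence the looser f32 tolerance assigned to
# these cases in _make_concrete_cases.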
def _make_cluster(to_cluster, n):
return _make_diag(
np.array([1000] * to_cluster + [1] * (n - to_cluster)))
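# Dense example matrices. Each generator maps (n, k) to a pair
# (matrix, eigenvalues-in-descending-order), cast to the requested dtype by
# cast_fn below.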
def _concrete_generators(dtype):
d = {
'id': lambda n, _k: _make_diag(np.ones(n)),
'linear cond=1k': lambda n, _k: _make_diag(np.linspace(1, 1000, n)),
'linear cond=100k':
lambda n, _k: _make_diag(np.linspace(1, 100 * 1000, n)),
'geom cond=1k': lambda n, _k: _make_diag(np.logspace(0, 3, n)),
'geom cond=100k': lambda n, _k: _make_diag(np.logspace(0, 5, n)),
'ring laplacian': lambda n, _k: _make_ring(n),
'cluster(k/2)': lambda n, k: _make_cluster(k // 2, n),
'cluster(k-1)': lambda n, k: _make_cluster(k - 1, n),
'cluster(k)': lambda n, k: _make_cluster(k, n)}
def cast_fn(fn):
def casted_fn(n, k):
result = fn(n, k)
cast = functools.partial(np.array, dtype=dtype)
return tuple(map(cast, result))
return casted_fn
return {k: cast_fn(v) for k, v in d.items()}
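# Matrix-free examples. Each helper below returns a triple
# (single-vector matvec, known eigenvalues, iteration budget m); the matvec
# is lifted to act on (n, k) blocks via jax.vmap in _callable_generators.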
def _make_id_fn(n):
return lambda x: x, np.ones(n), 5
def _make_diag_fn(diagonal, m):
return lambda x: diagonal.astype(x.dtype) * x, diagonal, m
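# Matrix-free version of the ring/path Laplacian above: L @ x is computed as
# 2 * x - roll(x, 1) - roll(x, -1), so the operator is never materialized.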
def _make_ring_fn(n, m):
_, eigs = _make_ring(n)
def ring_action(x):
degree = 2 * x
lnbr = jnp.roll(x, 1)
rnbr = jnp.roll(x, -1)
return degree - lnbr - rnbr
return ring_action, eigs, m
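# Low-rank Gram example: the action x -> T @ (T.T @ x) for a random (n, k)
# matrix T, whose nonzero eigenvalues are the squared singular values of T.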
def _make_randn_fn(n, k, m):
rng = np.random.default_rng(1234)
tall_skinny = rng.standard_normal((n, k))
def randn_action(x):
ts = jnp.array(tall_skinny, dtype=x.dtype)
p = jax.lax.Precision.HIGHEST
return ts.dot(ts.T.dot(x, precision=p), precision=p)
_, s, _ = np.linalg.svd(tall_skinny, full_matrices=False)
return randn_action, s ** 2, m
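# Sparse normal-equations example: draws a random n x n matrix S with the
# given fill fraction and tests the action x -> S.T @ (S @ x), applied through
# BCOO dot_general; the target eigenvalues are the squared singular values
# of S.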
def _make_sparse_fn(n, fill):
rng = np.random.default_rng(1234)
slots = n ** 2
filled = max(int(slots * fill), 1)
pos = rng.choice(slots, size=filled, replace=False)
posx, posy = divmod(pos, n)
data = rng.standard_normal(len(pos))
coo = sps.coo_matrix((data, (posx, posy)), shape=(n, n))
def sparse_action(x):
coo_typed = coo.astype(np.dtype(x.dtype))
sps_mat = bcoo.BCOO.from_scipy_sparse(coo_typed)
dn = (((1,), (0,)), ((), ())) # Good old fashioned matmul.
x = bcoo.bcoo_dot_general(sps_mat, x, dimension_numbers=dn)
sps_mat_T = sps_mat.transpose()
return bcoo.bcoo_dot_general(sps_mat_T, x, dimension_numbers=dn)
dense = coo.todense()
_, s, _ = np.linalg.svd(dense, full_matrices=False)
return sparse_action, s ** 2, 20
def _callable_generators(dtype):
n = 100
topk = 10
d = {
'id': _make_id_fn(n),
'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
'ring laplacian': _make_ring_fn(n, 40),
'randn': _make_randn_fn(n, topk, 40),
'sparse 1%': _make_sparse_fn(n, 0.01),
'sparse 10%': _make_sparse_fn(n, 0.10),
}
ret = {}
for k, (vec_mul_fn, eigs, m) in d.items():
if jtu.num_float_bits(dtype) > 32:
m *= 3
eigs.sort()
# Note we must lift the vector multiply into matmul
fn = jax.vmap(vec_mul_fn, in_axes=1, out_axes=1)
ret[k] = (fn, eigs[::-1][:topk].astype(dtype), n, m)
return ret
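# A minimal sketch of the call pattern exercised by the checks below (see
# checkLobpcgConsistency and checkApproxEigs for the authoritative usage):
#
#   theta, U, iters = linalg.lobpcg_standard(A, X, m, tol)
#
# where A is an (n, n) array or a callable acting on (n, k) blocks, X is the
# (n, k) initial guess, m bounds the iteration count, and tol is the
# convergence tolerance (the tests pass None to use the routine's default, or
# 0 to force all m iterations to run).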
@jtu.with_config(
jax_enable_checks=True,
jax_debug_nans=True,
jax_numpy_rank_promotion='raise',
jax_traceback_filtering='off')
@jtu.thread_unsafe_test_class() # matplotlib isn't thread-safe
class LobpcgTest(jtu.JaxTestCase):
def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
X = self.rng().standard_normal(size=(n, k)).astype(dtype)
A, X = (jnp.array(i, dtype=dtype) for i in (A, X))
theta, U, i = linalg.lobpcg_standard(A, X, m, tol)
self.assertDtypesMatch(theta, A)
self.assertDtypesMatch(U, A)
self.assertLess(
i, m, msg=f'expected early convergence iters {int(i)} < max {m}')
issorted = theta[:-1] >= theta[1:]
all_true = np.ones_like(issorted).astype(bool)
self.assertArraysEqual(issorted, all_true)
k = X.shape[1]
relerr = np.abs(theta - eigs[:k]) / eigs[:k]
for i in range(k):
# The self-consistency property should be ensured.
u = np.asarray(U[:, i], dtype=A.dtype)
t = float(theta[i])
Au = A.dot(u)
resid = Au - t * u
resid_norm = np.linalg.norm(resid)
vector_norm = np.linalg.norm(Au)
adjusted_error = resid_norm / n / (t + vector_norm) / 10
eps = float(jnp.finfo(dtype).eps) if tol is None else tol
self.assertLessEqual(
adjusted_error,
eps,
msg=f'convergence criterion for eigenvalue {i} not satisfied, '
f'floating point error {adjusted_error} not <= {eps}')
# There's no real guarantee we can be within x% of the true eigenvalue.
# However, for these simple unit test examples this should be met.
tol = float(np.sqrt(eps)) * 10
self.assertLessEqual(
relerr[i],
tol,
msg=f'expected relative error within {tol}, was {float(relerr[i])}'
f' for eigenvalue {i} (actual {float(theta[i])}, '
f'expected {float(eigs[i])})')
def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
del tol
A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
X = self.rng().standard_normal(size=(n, k)).astype(dtype)
_theta, _U, _i, info = linalg._lobpcg_standard_matrix(
A, X, m, tol=0, debug=True)
self.assertArraysEqual(info['X zeros'], jnp.zeros_like(info['X zeros']))
# To check for any divergence, make sure that the last 20% of
# steps have lower worst-case relerr than first 20% of steps,
# at least up to an order of magnitude.
#
# This is non-trivial, as many implementations have catastrophic
# cancellations at convergence for residual terms, and rely on
# brittle locking tolerance to avoid divergence.
eigs = eigs[:k]
relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
few_steps = max(m // 5, 1)
self.assertLess(
relerrs[-few_steps:].max(axis=1).mean(),
10 * relerrs[:few_steps].max(axis=1).mean())
self._possibly_plot(A, eigs, X, m, matrix_name)
def _possibly_plot(self, A, eigs, X, m, matrix_name):
if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
return
if isinstance(A, (np.ndarray, jax.Array)):
lobpcg = linalg._lobpcg_standard_matrix
else:
lobpcg = linalg._lobpcg_standard_callable
_theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots'
self._debug_plots(X, eigs, info, matrix_name, plot_dir)
def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir):
# We import matplotlib lazily because (a) it's faster this way, and
# (b) concurrent imports of matplotlib appear to trigger some sort of
# collision on the matplotlib cache lock on Windows.
try:
from matplotlib import pyplot as plt
except (ModuleNotFoundError, ImportError):
return # If matplotlib isn't available, don't emit plots.
os.makedirs(lobpcg_debug_plot_dir, exist_ok=True)
clean_matrix_name = _clean_matrix_name(matrix_name)
n, k = X.shape
dt = 'f32' if X.dtype == np.float32 else 'f64'
figpath = os.path.join(
lobpcg_debug_plot_dir,
f'{clean_matrix_name}_n{n}_k{k}_{dt}.png')
plt.switch_backend('Agg')
fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4))
fig.suptitle(fr'{matrix_name} ${n=},{k=}$, {dt}')
line_styles = [':', '--', '-.', '-']
for key, ls in zip(['X orth', 'P orth', 'P.X'], line_styles):
ax0.semilogy(info[key], ls=ls, label=key)
ax0.set_title('basis average orthogonality')
ax0.legend()
relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
keys = ['max', 'p50', 'min']
fns = [np.max, np.median, np.min]
for key, fn, ls in zip(keys, fns, line_styles):
ax1.semilogy(fn(relerrs, axis=1), ls=ls, label=key)
ax1.set_title('eigval relerr')
ax1.legend()
for key, ls in zip(['basis rank', 'converged', 'P zeros'], line_styles):
ax2.plot(info[key], ls=ls, label=key)
ax2.set_title('basis dimension counts')
ax2.legend()
prefix = 'adjusted residual'
for key, ls in zip(keys, line_styles):
ax3.semilogy(info[prefix + ' ' + key], ls=ls, label=key)
ax3.axhline(np.finfo(X.dtype).eps, label='eps', c='k')
ax3.legend()
ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}')
fig.savefig(figpath, bbox_inches='tight')
plt.close(fig)
def checkApproxEigs(self, example_name, dtype):
fn, eigs, n, m = _callable_generators(dtype)[example_name]
k = len(eigs)
X = self.rng().standard_normal(size=(n, k)).astype(dtype)
theta, U, iters = linalg.lobpcg_standard(fn, X, m, tol=0.0)
# Given tolerance is zero all iters should be used.
self.assertEqual(iters, m)
# Evaluate in f64.
as64 = functools.partial(np.array, dtype=np.float64)
theta, eigs, U = (as64(x) for x in (theta, eigs, U))
relerr = np.abs(theta - eigs) / eigs
UTU = U.T.dot(U)
tol = np.sqrt(jnp.finfo(dtype).eps) * 100
if example_name == 'ring laplacian':
tol = 1e-2
for i in range(k):
self.assertLessEqual(
relerr[i], tol,
msg=f'eigenvalue {i} (actual {theta[i]} expected {eigs[i]})')
self.assertAllClose(UTU[i, i], 1.0, rtol=tol)
UTU[i, i] = 0
self.assertArraysAllClose(UTU[i], np.zeros_like(UTU[i]), atol=tol)
self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name)
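# The concrete subclasses below pin the dtype (and, for f64, enable x64) and
# skip device/test combinations that are currently known to fail; see the
# TODOs in setUp.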
class F32LobpcgTest(LobpcgTest):
def setUp(self):
# TODO(phawkins): investigate this failure
if jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("Test is failing on CUDA gpus")
super().setUp()
def testLobpcgValidatesArguments(self):
A, _ = _concrete_generators(np.float32)['id'](100, 10)
X = self.rng().standard_normal(size=(100, 10)).astype(np.float32)
with self.assertRaisesRegex(ValueError, 'search dim > 0'):
linalg.lobpcg_standard(A, X[:,:0])
with self.assertRaisesRegex(ValueError, 'A, X must have same dtypes'):
linalg.lobpcg_standard(
lambda x: jnp.array(A).dot(x).astype(jnp.float16), X)
with self.assertRaisesRegex(ValueError, r'A must be \(100, 100\)'):
linalg.lobpcg_standard(A[:60, :], X)
with self.assertRaisesRegex(ValueError, r'search dim \* 5 < matrix dim'):
linalg.lobpcg_standard(A[:50, :50], X[:50])
@parameterized.named_parameters(_make_concrete_cases(f64=False))
@jtu.skip_on_devices("gpu")
def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol):
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)
@parameterized.named_parameters(_make_concrete_cases(f64=False))
def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)
@parameterized.named_parameters(_make_callable_cases(f64=False))
def testCallableMatricesF32(self, matrix_name):
self.checkApproxEigs(matrix_name, jnp.float32)
@jtu.with_config(jax_enable_x64=True)
class F64LobpcgTest(LobpcgTest):
def setUp(self):
# TODO(phawkins): investigate this failure
if jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("Test is failing on CUDA gpus")
super().setUp()
@parameterized.named_parameters(_make_concrete_cases(f64=True))
@jtu.skip_on_devices("tpu", "gpu")
def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol):
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64)
@parameterized.named_parameters(_make_concrete_cases(f64=True))
@jtu.skip_on_devices("tpu", "gpu")
def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol):
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64)
@parameterized.named_parameters(_make_callable_cases(f64=True))
@jtu.skip_on_devices("tpu", "gpu")
def testCallableMatricesF64(self, matrix_name):
self.checkApproxEigs(matrix_name, jnp.float64)
if __name__ == '__main__':
jax.config.parse_flags_with_absl()
absltest.main(testLoader=jtu.JaxTestLoader())