rocm_jax/tests/magma_linalg_test.py
Peter Hawkins 66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00

144 lines
5.2 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
from absl.testing import absltest
import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import linalg as lax_linalg
from jax._src.lib import gpu_solver
config.parse_flags_with_absl()
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
class MagmaLinalgTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)],
dtype=float_types + complex_types,
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
@jtu.run_on_devices("gpu")
def testEig(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
# TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
# complex128 in some configurations.
if dtype == np.complex128:
self.skipTest("MAGMA support for complex128 types is flaky.")
rng = jtu.rand_default(self.rng())
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
# Norm, adjusted for dimension and type.
def norm(x):
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)
def check_right_eigenvectors(a, w, vr):
self.assertTrue(
np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))
def check_left_eigenvectors(a, w, vl):
rank = len(a.shape)
aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
wC = jnp.conj(w)
check_right_eigenvectors(aH, wC, vl)
a, = args_maker()
results = lax_linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=True)
w = results[0]
if compute_left_eigenvectors:
check_left_eigenvectors(a, w, results[1])
if compute_right_eigenvectors:
check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])
self._CompileAndCheck(jnp.linalg.eig, args_maker, rtol=1e-3)
@jtu.sample_product(
shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
dtype=float_types + complex_types,
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
@jtu.run_on_devices("gpu")
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
# TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
# complex128 in some configurations.
if dtype == np.complex128:
self.skipTest("MAGMA support for complex128 types is flaky.")
a = jnp.full(shape, jnp.nan, dtype)
results = lax_linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=True)
for result in results:
self.assertTrue(np.all(np.isnan(result)))
def testEigMagmaConfig(self):
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
a = rng((5, 5), np.float32)
with config.gpu_use_magma("on"):
hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text()
self.assertIn('magma = "on"', hlo)
@jtu.sample_product(
shape=[(3, 4), (3, 3), (4, 3), (4, 3)],
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu")
def testPivotedQrFactorization(self, shape, dtype):
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
lax_func = partial(lax_linalg.qr, full_matrices=True, pivoting=True, use_magma=True)
sp_func = partial(jax.scipy.linalg.qr, mode="full", pivoting=True)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(sp_func, lax_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(lax_func, args_maker)
def testPivotedQrFactorizationMagmaConfig(self):
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
a = rng((5, 5), np.float32)
with config.gpu_use_magma("on"):
hlo = jax.jit(partial(lax_linalg.qr, pivoting=True, use_magma=True)).lower(a).as_text()
self.assertIn('magma = "on"', hlo)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())