# mirror of https://github.com/ROCm/jax.git
# synced 2025-04-22 09:06:06 +00:00
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""A JIT-compatible library for QDWH-based polar decomposition.
|
|
|
|
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
|
iteration implemented through QR decompositions does not require matrix
|
|
inversion. This is desirable for multicore and heterogeneous computing systems.
|
|
|
|
Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
|
|
"Optimizing Halley's iteration for computing the matrix polar decomposition."
|
|
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
|
|
https://epubs.siam.org/doi/abs/10.1137/090774999
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import lax
|
|
from jax._src import core
|
|
from jax._src.lax import linalg as lax_linalg
|
|
|
|
|
|
# Helpers for working with padded shapes
|
|
def _mask(x, dims, alternative=0):
|
|
"""Masks `x` up to the dynamic shape `dims`.
|
|
|
|
Replaces values outside those dimensions with `alternative`. `alternative` is
|
|
broadcast with `x`.
|
|
"""
|
|
assert jnp.ndim(x) == len(dims)
|
|
mask = None
|
|
for i, d in enumerate(dims):
|
|
if d is not None:
|
|
mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d
|
|
mask = mask_dim_i if mask is None else (mask & mask_dim_i)
|
|
return x if mask is None else jnp.where(mask, x, alternative)
|
|
|
|
def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
|
|
pads = [(0, 0, 0)] * x.ndim
|
|
pads[axis] = (low, high, interior)
|
|
return lax.pad(x, jnp.array(fill_value, x.dtype), pads)
|
|
|
|
def _dynamic_concat(a, b, m, axis=0):
|
|
"Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."
|
|
if m is None:
|
|
return jnp.concatenate([a, b], axis=axis)
|
|
return lax.dynamic_update_slice_in_dim(
|
|
_pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)
|
|
|
|
|
|
def _use_qr(u, m, n, params):
  """QDWH iteration using QR decomposition.

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters.
  """
  a_minus_e_by_sqrt_c, sqrt_c, e = params
  M, N = u.shape

  # Stack sqrt(c) * u on top of the N x N identity; `m` marks u's true height.
  stacked = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
  q, _ = lax_linalg.qr(stacked, full_matrices=False)
  # Top block, i.e. q[:m, :], masked down to the dynamic shape.
  top = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
  # Bottom block, conjugate-transposed, i.e. (q[m:, :]).T.conj().
  bottom = lax.dynamic_slice_in_dim(q, m, N, axis=0)
  bottom = _mask(bottom, (n, n)).T.conj()
  return e * u + a_minus_e_by_sqrt_c * (top @ bottom)
|
|
|
|
|
|
def _use_cholesky(u, m, n, params):
  """QDWH iteration using Cholesky decomposition.

  Args:
    u: a matrix, with static (padded) shape M x N
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters.
  """
  a_minus_e, c, e = params
  _, N = u.shape
  gram = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u))
  # Fill the padded lower-right corner with the identity instead of zeros so
  # the Cholesky factorization below still sees a positive-definite matrix.
  gram = _mask(gram, (n, n), jnp.eye(N, dtype=gram.dtype))

  # Lower-triangular Cholesky factor of `gram`.
  chol = lax_linalg.cholesky(gram, symmetrize_input=False)

  # Two triangular solves apply the inverse of `gram` to `u`.
  z = lax_linalg.triangular_solve(
      chol, u.T, left_side=True, lower=True, conjugate_a=True).conj()

  z = lax_linalg.triangular_solve(chol, z, left_side=True, lower=True,
                                  transpose_a=True, conjugate_a=True).T.conj()

  return e * u + a_minus_e * z
|
|
|
|
|
|
def _qdwh(x, m, n, max_iterations, eps):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: a matrix with static (padded) shape M x N.
    m, n: the dynamic (true) shape of the matrix, with m <= M and n <= N.
    max_iterations: hard cap on the total number of iterations.
    eps: tolerance seed; if None, the machine epsilon of ``x.dtype`` is used.

  Returns:
    A tuple ``(u, h, num_iters, is_converged)``: the approximately unitary
    polar factor ``u``, the symmetrized factor ``h = (u^H x + (u^H x)^H) / 2``,
    the number of iterations performed, and whether convergence was reached
    within ``max_iterations``.
  """

  # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
  # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
  # the smallest singular value of x.
  if eps is None:
    eps = float(jnp.finfo(x.dtype).eps)
  # sqrt(norm(x, 1) * norm(x, inf)) >= norm(x, 2), so the product of rsqrts
  # below is a valid 1/alpha.
  one_norm = jnp.linalg.norm(x, ord=1)
  inf_norm = jnp.linalg.norm(x, ord=jnp.inf)
  alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm)
  # Guard against an all-zero input, where the rsqrts above are infinite.
  alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse)
  u = x * alpha_inverse.astype(x.dtype)

  # `l` is the running lower-bound estimate for the smallest singular value of
  # the scaled iterate; it is a Python float, updated at trace time.
  l = eps

  # Iteration tolerances.
  tol_l = 10.0 * eps / 2.0
  tol_norm = jnp.cbrt(tol_l)

  def get_qr_params(a, b, c):
    # Reparameterizes the weights (a, b, c) into the form `_use_qr` consumes.
    e = b / c
    a_minus_e = a - e
    sqrt_c = c ** (1 / 2)
    return (a_minus_e / sqrt_c, sqrt_c, e)

  def get_chol_params(a, b, c):
    # Reparameterizes the weights (a, b, c) into the form `_use_cholesky`
    # consumes.
    e = b / c
    a_minus_e = a - e
    return (a_minus_e, c, e)

  # Iterations whose weight `c` exceeds this cutoff take the QR-based update;
  # the rest take the Cholesky-based update.
  CHOLESKY_CUTOFF = 100

  # Because `l` is a Python float, this loop runs at trace time: the dynamic
  # weights (a, b, c) of every iteration are precomputed here and baked into
  # the traced computation as constants.
  qr_coefs = []
  chol_coefs = []
  k = 0
  while l + tol_l < 1 and k < max_iterations:
    k += 1
    l2 = l * l
    dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3)
    sqd = (1.0 + dd) ** (1 / 2)
    a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2)
    b = (a - 1) ** 2 / 4
    c = a + b - 1
    # Update the lower bound on the smallest singular value after this step.
    l = l * (a + b * l2) / (1 + c * l2)
    if c > CHOLESKY_CUTOFF:
      qr_coefs.append(get_qr_params(a, b, c))
    else:
      chol_coefs.append(get_chol_params(a, b, c))

  def iteration(k, state, update_fn, coefs, test_convergence):
    """Performs one QDWH step; returns (new_u, is_not_converged)."""
    u, _ = state

    if coefs is None:
      # As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method.
      params = get_chol_params(3, 1, 3)
    else:
      params = lax.dynamic_index_in_dim(coefs, k, keepdims=False)

    u_prev = u
    u = update_fn(u, m, n, params)

    is_not_converged = True
    if test_convergence:
      is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm
    return u, is_not_converged

  def iterate(u, coefs, **kwargs):
    """Runs `iteration` once per precomputed coefficient tuple."""
    if not coefs:
      return u, True
    coefs = jnp.array(coefs).astype(x.dtype)
    body = functools.partial(iteration, coefs=coefs, **kwargs)
    return lax.fori_loop(0, len(coefs), body, (u, True))

  # QR-based steps first (large `c`), then Cholesky-based steps.
  u, _ = iterate(
      u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False
  )
  u, is_not_converged = iterate(
      u, coefs=chol_coefs, update_fn=_use_cholesky, test_convergence=True
  )

  # If l has converged but u still has not, continue with Halley's method
  # (coef = None) until convergence.
  def cond_fun(state):
    k, _, is_not_converged = state
    return jnp.logical_and(is_not_converged, k < max_iterations)

  def body_fun(state):
    k, u, is_not_converged = state
    u, is_not_converged = iteration(
        k,
        (u, is_not_converged),
        coefs=None,
        update_fn=_use_cholesky,
        test_convergence=True,
    )
    return k + 1, u, is_not_converged

  k = len(qr_coefs) + len(chol_coefs)
  num_iters, u, is_not_converged = lax.while_loop(
      cond_fun, body_fun, (k, u, is_not_converged)
  )

  # Applies Newton-Schulz refinement for better accuracy.
  u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

  h = u.T.conj() @ x
  h = (h + h.T.conj()) / 2

  # Converged within the maximum number of iterations.
  is_converged = jnp.logical_not(is_not_converged)

  return u, h, num_iters, is_converged
|
|
|
|
|
|
# TODO: Add pivoting.
@functools.partial(
    jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps')
)
def qdwh(
    x,
    *,
    is_hermitian: bool = False,
    max_iterations: int | None = None,
    eps: float | None = None,
    dynamic_shape: tuple[int, int] | None = None,
):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix of shape `M x N`, possibly padded up to that size
      from a smaller true shape (``dynamic_shape``).
    is_hermitian: True if `x` is Hermitian. Default to `False`. This parameter
      is currently unused, but exists for backward compatibility.
    eps: The final result will satisfy ``|x_k - x_k-1| < |x_k| *
      (4*eps)**(1/3)`` where `x_k` is the iterate.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied.
    dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

  Returns:
    A four-tuple ``(u, h, num_iters, is_converged)`` holding the polar
    decomposition `x = u * h`, the number of iterations taken to compute `u`,
    and a flag that is `True` when convergence was reached within the maximum
    number of iterations.
  """
  # TODO: Possibly take advantage of Hermitian inputs to speed up the QDWH step.
  is_hermitian = core.concrete_or_error(
      bool, is_hermitian, 'The `is_hermitian` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  if max_iterations is None:
    max_iterations = 10
  else:
    max_iterations = core.concrete_or_error(
        int, max_iterations, 'The `max_iterations` argument must be statically '
        'specified to use `qdwh` within JAX transformations.')

  num_rows, num_cols = x.shape
  if num_rows < num_cols:
    raise ValueError('The input matrix of shape M x N must have M >= N.')
  if dynamic_shape is None:
    m, n = num_rows, num_cols
  else:
    m, n = dynamic_shape
    # Zero out any padding so it cannot influence the iteration.
    x = _mask(x, (m, n))

  # Pin the matmul precision so the iteration behaves consistently across
  # backends with lower-precision default matmuls.
  with jax.default_matmul_precision('float32'):
    u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps)

  return u, h, num_iters, is_converged
|