# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
|
|
from functools import partial
|
|
from typing import Callable, NamedTuple, Optional, Union
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import lax
|
|
from jax._src.scipy.optimize.line_search import line_search
|
|
|
|
|
|


class _BFGSResults(NamedTuple):
  """Results from BFGS optimization.

  Parameters:
    converged: True if minimization converged.
    failed: True if line search failed.
    k: integer number of iterations of the BFGS update.
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of Jacobian evaluations.
    nhev: integer total number of Hessian evaluations.
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged, the l2-norm of this tensor should be below the
      tolerance.
    H_k: array containing the inverse of the estimated Hessian.
    old_old_fval: the objective value one iterate back, tracked for the line
      search.
    status: int describing the end state.
    line_search_status: int describing the line search end state (only
      meaningful if the line search failed).
  """
  converged: Union[bool, jnp.ndarray]
  failed: Union[bool, jnp.ndarray]
  k: Union[int, jnp.ndarray]
  nfev: Union[int, jnp.ndarray]
  ngev: Union[int, jnp.ndarray]
  nhev: Union[int, jnp.ndarray]
  x_k: jnp.ndarray
  f_k: jnp.ndarray
  g_k: jnp.ndarray
  H_k: jnp.ndarray
  old_old_fval: jnp.ndarray
  status: Union[int, jnp.ndarray]
  line_search_status: Union[int, jnp.ndarray]


_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST)
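# Note: matmul precision is pinned to HIGHEST here, presumably so the rank-two
# inverse-Hessian updates below do not accumulate rounding error on backends
# whose default matmul precision is lower (e.g. bfloat16 accumulation on TPU).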


def minimize_bfgs(
    fun: Callable,
    x0: jnp.ndarray,
    maxiter: Optional[int] = None,
    norm=jnp.inf,
    gtol: float = 1e-5,
    line_search_maxiter: int = 10,
) -> _BFGSResults:
  """Minimize a function using BFGS.

  Implements Algorithm 6.1 from Nocedal and Wright, 'Numerical Optimization',
  1999, pp. 136-143.

  Args:
    fun: function of the form f(x), where x is a flat ndarray, returning a
      real scalar. The function should be composed of operations with vjp
      defined.
    x0: initial guess.
    maxiter: maximum number of iterations.
    norm: order of norm for convergence check. Default inf.
    gtol: terminates minimization when |grad|_norm < gtol.
    line_search_maxiter: maximum number of line search iterations.

  Returns:
    Optimization result.
  """

  if maxiter is None:
    maxiter = jnp.size(x0) * 200

  d = x0.shape[0]

  initial_H = jnp.eye(d, dtype=x0.dtype)
  f_0, g_0 = jax.value_and_grad(fun)(x0)
  state = _BFGSResults(
      converged=jnp.linalg.norm(g_0, ord=norm) < gtol,
      failed=False,
      k=0,
      nfev=1,
      ngev=1,
      nhev=0,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      H_k=initial_H,
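      # Heuristic guess for the objective value one iteration back; the line
      # search uses it when sizing its first trial step.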
      old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2,
      status=0,
      line_search_status=0,
  )

  def cond_fun(state):
    return (jnp.logical_not(state.converged)
            & jnp.logical_not(state.failed)
            & (state.k < maxiter))

  def body_fun(state):
    # Quasi-Newton search direction: steepest descent preconditioned by the
    # current inverse-Hessian estimate.
    p_k = -_dot(state.H_k, state.g_k)
    line_search_results = line_search(
        fun,
        state.x_k,
        p_k,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )
    state = state._replace(
        nfev=state.nfev + line_search_results.nfev,
        ngev=state.ngev + line_search_results.ngev,
        failed=line_search_results.failed,
        line_search_status=line_search_results.status,
    )
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(_dot(y_k, s_k))
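
    # BFGS update of the inverse Hessian estimate (Nocedal and Wright):
    #   H_{k+1} = (I - rho_k s_k y_k^T) H_k (I - rho_k y_k s_k^T)
    #             + rho_k s_k s_k^T,  with rho_k = 1 / (y_k . s_k).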
    sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
    w = jnp.eye(d) - rho_k * sy_k
    H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w)
             + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
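    # When y_k . s_k == 0, rho_k is non-finite; skip the update and keep the
    # previous estimate in that case.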
    H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(
        converged=converged,
        k=state.k + 1,
        x_k=x_kp1,
        f_k=f_kp1,
        g_k=g_kp1,
        H_k=H_kp1,
        old_old_fval=state.f_k,
    )
    return state

  state = lax.while_loop(cond_fun, body_fun, state)
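  # Translate the loop-exit condition into the reported integer status code.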
  status = jnp.where(
      state.converged,
      0,  # converged
      jnp.where(
          state.k == maxiter,
          1,  # max iters reached
          jnp.where(
              state.failed,
              2 + state.line_search_status,  # ls failed (+ reason)
              -1,  # undefined
          ),
      ),
  )
  state = state._replace(status=status)
  return state
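

# Usage sketch (illustrative; `minimize_bfgs` is private API, and user code
# would normally call `jax.scipy.optimize.minimize(fun, x0, method='BFGS')`,
# which wraps this routine). Minimize the 2-D Rosenbrock function, whose
# unique minimum sits at (1., 1.):
if __name__ == '__main__':
  def _rosenbrock(x):
    return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

  _results = minimize_bfgs(_rosenbrock, jnp.zeros(2))
  print('converged:', _results.converged)  # expect: True
  print('x_k:', _results.x_k)              # expect: approximately [1. 1.]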