Update JAX line search to behave more like SciPy

This commit is contained in:
Gregory Thornton 2021-05-13 12:08:06 +01:00 committed by GregThornton
parent 62603fde67
commit 03a1ee9269
3 changed files with 67 additions and 20 deletions

View File

@ -55,6 +55,7 @@ class _BFGSResults(NamedTuple):
f_k: jnp.ndarray
g_k: jnp.ndarray
H_k: jnp.ndarray
old_old_fval: jnp.ndarray
status: Union[int, jnp.ndarray]
line_search_status: Union[int, jnp.ndarray]
@ -108,6 +109,7 @@ def minimize_bfgs(
f_k=f_0,
g_k=g_0,
H_k=initial_H,
old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2,
status=0,
line_search_status=0,
)
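This seeding matches SciPy's _minimize_bfgs, which initializes old_old_fval = old_fval + norm(gfk) / 2 before the first iteration. A minimal sketch of the consequence, assuming the first search direction is p_0 = -g_0 (H_0 is the identity): the line search's first trial step works out to min(1, 1.01 / ||g_0||), so steep starting points take a correspondingly small first step.

    import jax.numpy as jnp

    def first_trial_step(f_0, g_0):
        # Synthetic "previous" objective value, as SciPy seeds it.
        old_old_fval = f_0 + jnp.linalg.norm(g_0) / 2
        # With p_0 = -g_0, the initial directional derivative is -||g_0||^2.
        dphi_0 = -jnp.vdot(g_0, g_0)
        # The heuristic in line_search (below) then reduces to 1.01 / ||g_0||.
        candidate = 1.01 * 2 * (f_0 - old_old_fval) / dphi_0
        return jnp.where(candidate > 1, 1.0, candidate)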
@ -124,6 +126,7 @@ def minimize_bfgs(
state.x_k,
p_k,
old_fval=state.f_k,
old_old_fval=state.old_old_fval,
gfk=state.g_k,
maxiter=line_search_maxiter,
)
@ -153,7 +156,8 @@ def minimize_bfgs(
x_k=x_kp1,
f_k=f_kp1,
g_k=g_kp1,
H_k=H_kp1
H_k=H_kp1,
old_old_fval=state.f_k,
)
return state
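End to end, the public API is unchanged. A quick way to exercise the new behavior (illustrative problem; at this x0 the gradient is large, so the first line-search step is scaled down to roughly 1.01 / ||g_0|| rather than the previous fixed 1.0):

    import jax.numpy as jnp
    from jax.scipy.optimize import minimize

    def f(x):
        return jnp.sum(x ** 2)

    res = minimize(f, jnp.array([1e3, -1e3]), method='BFGS')
    print(res.x, res.fun, res.status)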

View File

@ -104,15 +104,16 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
def body(state):
# Body of the zoom algorithm. We use boolean arithmetic instead of
# jax.lax.cond so that it works on GPU/TPU.
dalpha = (state.a_hi - state.a_lo)
a = jnp.minimum(state.a_hi, state.a_lo)
b = jnp.maximum(state.a_hi, state.a_lo)
dalpha = (b - a)
cchk = delta1 * dalpha
qchk = delta2 * dalpha
# This will cause the line search to stop, and since the Wolfe conditions
# are not satisfied the minimization should stop too.
state = state._replace(failed=state.failed | (dalpha <= 1e-10))
threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10)
state = state._replace(failed=state.failed | (dalpha <= threshold))
# _cubicmin sometimes returns nan, though in that case the bounds check will fail.
a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
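The dtype-dependent threshold matters because a bracket of width 1e-10 is not resolvable in single precision (float32 eps is about 1.2e-7), so zoom could spin until the iteration cap without ever shrinking the interval. A minimal sketch of the check, assuming dalpha is a JAX floating-point scalar:

    import jax.numpy as jnp

    def bracket_too_small(dalpha):
        # Loosen the cutoff when running below 64-bit precision.
        threshold = jnp.where(jnp.finfo(dalpha.dtype).bits < 64, 1e-5, 1e-10)
        return dalpha <= threshold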
@ -169,9 +170,9 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
hi_to_lo,
state._asdict(),
dict(
a_hi=a_lo,
phi_hi=phi_lo,
dphi_hi=dphi_lo,
a_hi=state.a_lo,
phi_hi=state.phi_lo,
dphi_hi=state.dphi_lo,
a_rec=state.a_hi,
phi_rec=state.phi_hi,
),
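The state. prefixes added above fix a stale-closure bug: the bare names a_lo, phi_lo, and dphi_lo are _zoom's arguments, captured when the loop body was traced, not the current loop-carried values, so the hi/lo swap silently reused iteration-zero data. A generic illustration of the pitfall, using a hypothetical toy loop:

    import jax.numpy as jnp
    from jax.lax import while_loop

    x_init = jnp.float32(0.0)

    def body(x):
        # WRONG: `x_init + 1.0` would use the traced constant 0.0 on every
        # iteration, so the loop could never progress past 1.0.
        # RIGHT: use the loop-carried operand.
        return x + 1.0

    result = while_loop(lambda x: x < 5.0, body, x_init)  # 5.0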
@ -191,6 +192,9 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
),
)
state = state._replace(j=state.j + 1)
# Choose a higher iteration cutoff than SciPy's maxiter, as JAX can take
# longer to find the same value - possibly floating point issues?
state = state._replace(failed=state.failed | (state.j >= 30))
return state
state = while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed),
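Since lax.while_loop has no built-in iteration limit, the cap is folded into the failed flag, which the loop condition above already checks: once j reaches 30 the flag flips and the loop exits. As a generic pattern, assuming a NamedTuple state with .j and .failed fields:

    from jax.lax import while_loop

    def run_capped(cond_fun, body_fun, init_state, cap=30):
        def capped_body(s):
            s = body_fun(s)
            # Parenthesize the comparison: `|` binds tighter than `>=`.
            return s._replace(failed=s.failed | (s.j >= cap))
        return while_loop(lambda s: cond_fun(s) & (~s.failed),
                          capped_body, init_state)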
@ -213,7 +217,6 @@ class _LineSearchState(NamedTuple):
phi_star: Union[float, jnp.ndarray]
dphi_star: Union[float, jnp.ndarray]
g_star: jnp.ndarray
saddle_point: Union[bool, jnp.ndarray]
class _LineSearchResults(NamedTuple):
@ -269,6 +272,11 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
else:
phi_0 = old_fval
dphi_0 = jnp.dot(gfk, pk)
if old_old_fval is not None:
candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
else:
start_value = 1
def wolfe_one(a_i, phi_i):
# actually negation of W1
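This is SciPy's initial-step rule (cf. Nocedal & Wright's alpha_0 = 2(f_k - f_{k-1}) / phi'(0)): choose the step at which a quadratic model through phi_0 with slope dphi_0 would reproduce the previous iteration's decrease, pad it by 1%, and cap it at 1 so the plain BFGS unit step is still tried whenever it is plausible. As a standalone sketch:

    import jax.numpy as jnp

    def initial_step(phi_0, old_old_fval, dphi_0):
        # A quadratic q with q(0) = phi_0 and q'(0) = dphi_0 achieves a
        # decrease of -dphi_0 * a / 2 at its minimizer a; matching that to
        # the previous decrease (old_old_fval - phi_0) solves to:
        candidate = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
        return jnp.where(candidate > 1, 1.0, candidate)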
@ -292,18 +300,12 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
phi_star=phi_0,
dphi_star=dphi_0,
g_star=gfk,
saddle_point=False,
)
def body(state):
# No amax in this version; we just double the step, as in scipy.
# Unlike the original algorithm, we make the next choice of step size at
# the start of this loop.
a_i = jnp.where(state.i == 1, 1., state.a_i1 * 2.)
# if a_i <= 0 then something went wrong. In practice any really small step
# length is a failure. Likely means the search pk is not good, perhaps we
# are at a saddle point.
saddle_point = a_i < 1e-5
state = state._replace(failed=saddle_point, saddle_point=saddle_point)
a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)
phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
state = state._replace(nfev=state.nfev + 1,
@ -384,25 +386,28 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
state)
status = jnp.where(
state.failed & (~state.saddle_point),
state.failed,
jnp.array(1), # zoom failed
jnp.where(
state.failed & state.saddle_point,
jnp.array(2), # saddle point reached,
jnp.where(
state.i > maxiter,
jnp.array(3), # maxiter reached
jnp.array(0), # passed (should be)
),
),
)
# Step sizes that are too small cause the optimizer to get stuck with a
# zero direction in <64-bit mode - avoid this with a floor on the minimum
# step size.
alpha_k = state.a_star
alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
& (jnp.abs(alpha_k) < 1e-8),
jnp.sign(alpha_k) * 1e-8,
alpha_k)
results = _LineSearchResults(
failed=state.failed | (~state.done),
nit=state.i - 1, # because iterations started at 1
nfev=state.nfev,
ngev=state.ngev,
k=state.i,
a_k=state.a_star,
a_k=alpha_k,
f_k=state.phi_star,
g_k=state.g_star,
status=status,
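The floor guards against a failure mode specific to <64-bit mode: when |alpha| is far below the floating-point resolution at x_k, the update x_k + alpha * p_k rounds back to exactly x_k, the gradient difference is zero, and the optimizer stalls. A quick demonstration of the rounding, with illustrative magnitudes (92 is roughly the scale of the "bad" initial value in the tests below):

    import jax.numpy as jnp

    x = jnp.float32(92.0)
    step = jnp.float32(-1e-9)
    print(x + step == x)  # True: the update is lost to rounding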

View File

@ -57,6 +57,13 @@ def eggholder(np):
return func
def zakharovFromIndices(x, ii):
sum1 = (x ** 2).sum()
sum2 = (0.5 * ii * x).sum()
answer = sum1 + sum2 ** 2 + sum2 ** 4
return answer
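For reference, this is the Zakharov test function, f(x) = sum_i x_i^2 + (0.5 * sum_i i*x_i)^2 + (0.5 * sum_i i*x_i)^4, whose global minimum is f = 0 at the origin - the property test_zakharov below asserts. A quick check (illustrative, run alongside the definition above):

    import jax.numpy as jnp

    x = jnp.zeros(6)
    ii = jnp.arange(1, 7)
    assert zakharovFromIndices(x=x, ii=ii) == 0.0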
class TestBFGS(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
@ -94,6 +101,37 @@ class TestBFGS(jtu.JaxTestCase):
results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)
@jtu.skip_on_flag('jax_enable_x64', False)
def test_zakharov(self):
def zakharov_fn(x):
ii = jnp.arange(1, len(x) + 1, step=1)
answer = zakharovFromIndices(x=x, ii=ii)
return answer
x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
eval_func = jax.jit(zakharov_fn)
jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
self.assertLess(jax_res.fun, 1e-6)
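This start point is what stresses the new initial-step heuristic: the sum2**4 term makes the gradient at x0 astronomically large, so a fixed first trial step of 1.0 would overshoot wildly, while min(1, 1.01 / ||g_0||) stays tiny. A rough scale check, reusing the test's own definitions (not part of the test):

    g0 = jax.grad(zakharov_fn)(x0)
    print(jnp.linalg.norm(g0))                            # >> 1
    print(jnp.minimum(1.0, 1.01 / jnp.linalg.norm(g0)))   # first trial step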
def test_minimize_bad_initial_values(self):
# This test deliberately uses "bad" initial values to check that the
# handling of failed line searches, etc. is the same across implementations.
initial_value = jnp.array([92, 0.001])
opt_fn = himmelblau(jnp)
jax_res = jax.scipy.optimize.minimize(
fun=opt_fn,
x0=initial_value,
method='BFGS',
).x
scipy_res = scipy.optimize.minimize(
fun=opt_fn,
jac=jax.grad(opt_fn),
method='BFGS',
x0=initial_value
).x
self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)
def test_args_must_be_tuple(self):
A = jnp.eye(2) * 1e4
def f(x):