Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Update Jax linesearch to behave more like Scipy
This commit is contained in:
parent
62603fde67
commit
03a1ee9269
@@ -55,6 +55,7 @@ class _BFGSResults(NamedTuple):
   f_k: jnp.ndarray
   g_k: jnp.ndarray
   H_k: jnp.ndarray
+  old_old_fval: jnp.ndarray
   status: Union[int, jnp.ndarray]
   line_search_status: Union[int, jnp.ndarray]
@@ -108,6 +109,7 @@ def minimize_bfgs(
       f_k=f_0,
       g_k=g_0,
       H_k=initial_H,
+      old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2,
       status=0,
       line_search_status=0,
   )
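For reference, SciPy's BFGS driver seeds the "previous" function value the same way before any step has been taken. A minimal standalone sketch of the seeding (the function name is ours, not the library's):

import jax.numpy as jnp

def seed_old_old_fval(f_0, g_0):
  # Pretend the previous function value sat norm(g_0)/2 above the current
  # one, giving the first line search a finite, gradient-scaled decrease
  # estimate to derive its initial step from.
  return f_0 + jnp.linalg.norm(g_0) / 2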
@@ -124,6 +126,7 @@ def minimize_bfgs(
       state.x_k,
       p_k,
       old_fval=state.f_k,
+      old_old_fval=state.old_old_fval,
       gfk=state.g_k,
       maxiter=line_search_maxiter,
   )
@@ -153,7 +156,8 @@ def minimize_bfgs(
         x_k=x_kp1,
         f_k=f_kp1,
         g_k=g_kp1,
-        H_k=H_kp1
+        H_k=H_kp1,
+        old_old_fval=state.f_k,
     )
     return state
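Taken together, these hunks thread old_old_fval through the BFGS loop: the line search consumes the last two function values, and after each accepted step the pair shifts by one iteration. A hypothetical skeleton of just that bookkeeping (the step callable stands in for the line search plus BFGS update; everything else is elided):

import jax.numpy as jnp

def fval_bookkeeping(f_0, g_0, step, iters=3):
  f_k = f_0
  old_old_fval = f_0 + jnp.linalg.norm(g_0) / 2  # seed from the init hunk
  for _ in range(iters):
    f_kp1 = step(f_k, old_old_fval)   # line search sees both values
    old_old_fval, f_k = f_k, f_kp1    # matches old_old_fval=state.f_k above
  return f_k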
@@ -104,15 +104,16 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
   def body(state):
     # Body of zoom algorithm. We use boolean arithmetic to avoid using jax.cond
     # so that it works on GPU/TPU.
-    dalpha = (state.a_hi - state.a_lo)
+    a = jnp.minimum(state.a_hi, state.a_lo)
+    b = jnp.maximum(state.a_hi, state.a_lo)
+    dalpha = (b - a)
     cchk = delta1 * dalpha
     qchk = delta2 * dalpha

     # This will cause the line search to stop, and since the Wolfe conditions
     # are not satisfied the minimization should stop too.
-    state = state._replace(failed=state.failed | (dalpha <= 1e-10))
+    threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10)
+    state = state._replace(failed=state.failed | (dalpha <= threshold))

     # Cubmin is sometimes nan, though in this case the bounds check will fail.
     a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
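Two separate fixes in this hunk. First, a_hi is not guaranteed to be the larger endpoint of the zoom bracket, so the old state.a_hi - state.a_lo could be negative and trip the dalpha <= 1e-10 failure check on a perfectly healthy interval; ordering the endpoints first makes dalpha a true width. Second, the failure threshold is loosened to 1e-5 below 64-bit precision, where 1e-10 is finer than float32 can resolve. A standalone demonstration (not the library code):

import jax.numpy as jnp

def interval_width(a_lo, a_hi):
  # Zoom's bracket may arrive reversed; order the endpoints before
  # measuring the width.
  a = jnp.minimum(a_hi, a_lo)
  b = jnp.maximum(a_hi, a_lo)
  return b - a

print(0.5 - 0.9)                  # -0.4: naive a_hi - a_lo is negative
print(interval_width(0.9, 0.5))   # 0.4: ordered width is positive

dalpha = jnp.float32(0.4)
threshold = jnp.where(jnp.finfo(dalpha.dtype).bits < 64, 1e-5, 1e-10)
print(threshold)                  # 1e-05 for float32 inputs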
@@ -169,9 +170,9 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
         hi_to_lo,
         state._asdict(),
         dict(
-            a_hi=a_lo,
-            phi_hi=phi_lo,
-            dphi_hi=dphi_lo,
+            a_hi=state.a_lo,
+            phi_hi=state.phi_lo,
+            dphi_hi=state.dphi_lo,
             a_rec=state.a_hi,
             phi_rec=state.phi_hi,
         ),
@@ -191,6 +192,9 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
         ),
     )
     state = state._replace(j=state.j + 1)
+    # Choose higher cutoff for maxiter than Scipy as Jax takes longer to find
+    # the same value - possibly floating point issues?
+    state = state._replace(failed=state.failed | (state.j >= 30))
     return state

   state = while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed),
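Because the failure flag is folded into the loop state rather than raised, the cap composes with jax.lax.while_loop. A minimal sketch of the bounded-loop pattern with a generic (value, counter, failed) carry (names ours):

import jax
import jax.numpy as jnp

def capped_loop(step, init, max_iters=30):
  def cond(carry):
    value, j, failed = carry
    return (~failed) & (j < max_iters)

  def body(carry):
    value, j, failed = carry
    value = step(value)
    j = j + 1
    failed = failed | (j >= max_iters)  # flag failure instead of raising
    return value, j, failed

  return jax.lax.while_loop(cond, body, (init, 0, False))

In the diff, the counter and flag live in the zoom state NamedTuple instead of a bare tuple, but the pattern is the same.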
@@ -213,7 +217,6 @@ class _LineSearchState(NamedTuple):
   phi_star: Union[float, jnp.ndarray]
   dphi_star: Union[float, jnp.ndarray]
   g_star: jnp.ndarray
-  saddle_point: Union[bool, jnp.ndarray]


 class _LineSearchResults(NamedTuple):
@@ -269,6 +272,11 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
   else:
     phi_0 = old_fval
     dphi_0 = jnp.dot(gfk, pk)
+  if old_old_fval is not None:
+    candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
+    start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
+  else:
+    start_value = 1

   def wolfe_one(a_i, phi_i):
     # actually negation of W1
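This reproduces SciPy's initial-step heuristic, alpha0 = min(1, 1.01 * 2 * (phi0 - old_phi0) / derphi0). On the very first BFGS iteration the seed from the init hunk gives a closed form: with H_0 = I the direction is p_0 = -g_0, so dphi_0 = -||g_0||^2 and the candidate collapses to 1.01 / ||g_0||. A quick numeric check of that claim (values ours):

import jax.numpy as jnp

g_0 = jnp.array([3.0, 4.0])     # ||g_0|| = 5
f_0 = 10.0
old_old_fval = f_0 + jnp.linalg.norm(g_0) / 2
dphi_0 = jnp.dot(g_0, -g_0)     # -25: slope along p_0 = -g_0
candidate = 1.01 * 2 * (f_0 - old_old_fval) / dphi_0
print(candidate)                # 0.202 == 1.01 / 5
print(jnp.where(candidate > 1, 1.0, candidate))   # capped at 1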
@ -292,18 +300,12 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
|
||||
phi_star=phi_0,
|
||||
dphi_star=dphi_0,
|
||||
g_star=gfk,
|
||||
saddle_point=False,
|
||||
)
|
||||
|
||||
def body(state):
|
||||
# no amax in this version, we just double as in scipy.
|
||||
# unlike original algorithm we do our next choice at the start of this loop
|
||||
a_i = jnp.where(state.i == 1, 1., state.a_i1 * 2.)
|
||||
# if a_i <= 0 then something went wrong. In practice any really small step
|
||||
# length is a failure. Likely means the search pk is not good, perhaps we
|
||||
# are at a saddle point.
|
||||
saddle_point = a_i < 1e-5
|
||||
state = state._replace(failed=saddle_point, saddle_point=saddle_point)
|
||||
a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)
|
||||
|
||||
phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
|
||||
state = state._replace(nfev=state.nfev + 1,
|
||||
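With the ad-hoc saddle-point bailout removed, failure detection moves entirely to the status logic below, and the trial steps simply follow start_value * 2**(i-1), doubling as in SciPy. A short sketch of the schedule (standalone; trial_step is our name):

import jax.numpy as jnp

def trial_step(i, a_prev, start_value):
  # i is 1-based, matching state.i in the loop above.
  return jnp.where(i == 1, start_value, a_prev * 2.)

a = 0.0
for i in range(1, 5):
  a = trial_step(i, a, start_value=0.202)
  print(i, float(a))   # 0.202, 0.404, 0.808, 1.616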
@@ -384,25 +386,28 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
                      state)

   status = jnp.where(
-      state.failed & (~state.saddle_point),
+      state.failed,
       jnp.array(1),  # zoom failed
       jnp.where(
-          state.failed & state.saddle_point,
-          jnp.array(2),  # saddle point reached,
-          jnp.where(
-              state.i > maxiter,
-              jnp.array(3),  # maxiter reached
-              jnp.array(0),  # passed (should be)
-          ),
+          state.i > maxiter,
+          jnp.array(3),  # maxiter reached
+          jnp.array(0),  # passed (should be)
       ),
   )
+  # Step sizes which are too small causes the optimizer to get stuck with a
+  # direction of zero in <64 bit mode - avoid with a floor on minimum step size.
+  alpha_k = state.a_star
+  alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
+                      & (jnp.abs(alpha_k) < 1e-8),
+                      jnp.sign(alpha_k) * 1e-8,
+                      alpha_k)
   results = _LineSearchResults(
       failed=state.failed | (~state.done),
       nit=state.i - 1,  # because iterations started at 1
       nfev=state.nfev,
       ngev=state.ngev,
       k=state.i,
-      a_k=state.a_star,
+      a_k=alpha_k,
       f_k=state.phi_star,
       g_k=state.g_star,
       status=status,
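Two behaviors change here: the saddle-point status code (2) disappears along with the removed state field, and the accepted step gets a magnitude floor in sub-64-bit modes, so that x_{k+1} = x_k + alpha_k * p_k cannot silently become a zero update in float32. A standalone sketch of the clamp (function name ours):

import jax.numpy as jnp

def floor_step_size(alpha, floor=1e-8):
  # Keep the sign, clamp the magnitude; only in 32-bit (or lower) modes.
  is_sub64 = jnp.finfo(alpha.dtype).bits != 64
  too_small = jnp.abs(alpha) < floor
  return jnp.where(is_sub64 & too_small, jnp.sign(alpha) * floor, alpha)

print(floor_step_size(jnp.float32(3e-12)))  # 1e-08
print(floor_step_size(jnp.float32(0.25)))   # 0.25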
@@ -57,6 +57,13 @@ def eggholder(np):
   return func


+def zakharovFromIndices(x, ii):
+  sum1 = (x**2).sum()
+  sum2 = (0.5*ii*x).sum()
+  answer = sum1+sum2**2+sum2**4
+  return answer
+
+
 class TestBFGS(jtu.JaxTestCase):

   @parameterized.named_parameters(jtu.cases_from_list(
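The Zakharov function f(x) = sum_i x_i^2 + (sum_i 0.5*i*x_i)^2 + (sum_i 0.5*i*x_i)^4 has a single global minimum, f = 0 at x = 0, which is what the new test below asserts against from a far-away start. A quick check of the helper at the optimum (reusing the definition above):

import jax
import jax.numpy as jnp

def zakharovFromIndices(x, ii):
  sum1 = (x**2).sum()
  sum2 = (0.5*ii*x).sum()
  return sum1 + sum2**2 + sum2**4

x = jnp.zeros(6)
ii = jnp.arange(1, 7)
print(zakharovFromIndices(x, ii))            # 0.0 at the minimum
print(jax.grad(zakharovFromIndices)(x, ii))  # zero gradient there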
@@ -94,6 +101,37 @@ class TestBFGS(jtu.JaxTestCase):
     results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
     self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

+  @jtu.skip_on_flag('jax_enable_x64', False)
+  def test_zakharov(self):
+    def zakharov_fn(x):
+      ii = jnp.arange(1, len(x) + 1, step=1)
+      answer = zakharovFromIndices(x=x, ii=ii)
+      return answer
+
+    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
+    eval_func = jax.jit(zakharov_fn)
+    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
+    self.assertLess(jax_res.fun, 1e-6)
+
+  def test_minimize_bad_initial_values(self):
+    # This test runs deliberately "bad" initial values to test that handling
+    # of failed line search, etc. is the same across implementations
+    initial_value = jnp.array([92, 0.001])
+    opt_fn = himmelblau(jnp)
+    jax_res = jax.scipy.optimize.minimize(
+        fun=opt_fn,
+        x0=initial_value,
+        method='BFGS',
+    ).x
+    scipy_res = scipy.optimize.minimize(
+        fun=opt_fn,
+        jac=jax.grad(opt_fn),
+        method='BFGS',
+        x0=initial_value
+    ).x
+    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)
+
+
   def test_args_must_be_tuple(self):
     A = jnp.eye(2) * 1e4
     def f(x):