# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest, parameterized
import numpy as np
import scipy
import scipy.optimize

from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import jit
from jax.config import config
import jax.scipy.optimize

config.parse_flags_with_absl()


def rosenbrock(np):
  def func(x):
    # np.diff(x) is x[1:] - x[:-1] (not the classic x[1:] - x[:-1]**2), so this
    # is a simplified Rosenbrock variant; its minimum is still at x = ones.
    return np.sum(100. * np.diff(x) ** 2 + (1. - x[:-1]) ** 2)
  return func


def himmelblau(np):
  def func(p):
    x, y = p
    return (x ** 2 + y - 11.) ** 2 + (x + y ** 2 - 7.) ** 2
  return func


def matyas(np):
  def func(p):
    x, y = p
    return 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y
  return func


def eggholder(np):
  def func(p):
    x, y = p
    return - (y + 47) * np.sin(np.sqrt(np.abs(x / 2. + y + 47.))) - x * np.sin(
        np.sqrt(np.abs(x - (y + 47.))))
  return func


def zakharovFromIndices(x, ii):
  sum1 = (x**2).sum()
  sum2 = (0.5*ii*x).sum()
  answer = sum1+sum2**2+sum2**4
  return answer
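

# Illustrative sketch (an addition, not an upstream test): `rosenbrock` above is
# a simplified variant of the classic function. This hypothetical helper
# evaluates both the variant and SciPy's `scipy.optimize.rosen`, which
# implements the classic form, so the difference can be checked numerically.
# Both vanish at x = ones, their shared minimum.
def _rosenbrock_variant_vs_classic(x):
  import numpy as np
  import scipy.optimize

  x = np.asarray(x, dtype=float)
  # Variant used in this file: penalizes x[1:] - x[:-1].
  variant = np.sum(100. * np.diff(x) ** 2 + (1. - x[:-1]) ** 2)
  # Classic Rosenbrock: penalizes x[1:] - x[:-1]**2.
  classic = scipy.optimize.rosen(x)
  return variant, classic


# For example, at x = [0.5, 0.25] the variant gives 6.5 while the classic
# function gives 0.25; at x = ones both return 0.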


class TestBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}",
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.ones(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    # Note: we cannot compare step for step with scipy's BFGS because our line
    # search is _slightly_ different.
    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='BFGS',
          options=dict(maxiter=maxiter, gtol=1e-6),
      )
      return result.x

    jax_res = min_op(x0)
    scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_fixes4594(self):
    n = 2
    A = jnp.eye(n) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
    self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

  @jtu.skip_on_flag('jax_enable_x64', False)
  def test_zakharov(self):
    def zakharov_fn(x):
      ii = jnp.arange(1, len(x) + 1, step=1, dtype=x.dtype)
      answer = zakharovFromIndices(x=x, ii=ii)
      return answer

    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
    eval_func = jax.jit(zakharov_fn)
    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
    self.assertLess(jax_res.fun, 1e-6)

  def test_minimize_bad_initial_values(self):
    # This test uses deliberately "bad" initial values to check that failed
    # line searches, etc. are handled the same way across implementations.
    initial_value = jnp.array([92, 0.001])
    opt_fn = himmelblau(jnp)
    jax_res = jax.scipy.optimize.minimize(
        fun=opt_fn,
        x0=initial_value,
        method='BFGS',
    ).x
    scipy_res = scipy.optimize.minimize(
        fun=opt_fn,
        jac=jax.grad(opt_fn),
        method='BFGS',
        x0=initial_value
    ).x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_args_must_be_tuple(self):
    A = jnp.eye(2) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    with self.assertRaisesRegex(TypeError, "args .* must be a tuple"):
      jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
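

# Illustrative sketch (an addition, not an upstream test): the test above checks
# that a non-tuple `args` is rejected. When `args` is a tuple, `minimize`
# forwards the extra values to the objective as `fun(x, *args)`. The objective
# and the scale value 10. below are arbitrary, hypothetical choices.
def _minimize_with_args_sketch():
  import jax.numpy as jnp
  import jax.scipy.optimize

  def f(x, scale):
    # Scaled quadratic whose minimum is at x = 0.
    return jnp.mean((scale * x) ** 2)

  # `args` must be a tuple; its entries are appended after x when calling f.
  return jax.scipy.optimize.minimize(
      f, jnp.ones(2), args=(10.,), method='BFGS')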


class TestLBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}",
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.zeros(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(maxiter=maxiter, gtol=1e-7),
      )
      return result.x

    jax_res = min_op(x0)

    # Note that without bounds, L-BFGS-B is just L-BFGS
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

    if func.__name__ == 'matyas':
      # scipy performs poorly on Matyas, so compare to the true minimum instead
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
      return

    if func.__name__ == 'eggholder':
      # L-BFGS performs poorly on the eggholder function. Neither scipy nor
      # JAX finds the true minimum, so we can only loosely compare the two
      # (incorrect) results, using a large atol.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
      return

    self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)

  def test_minimize_complex_sphere(self):
    z0 = jnp.array([1., 2. - 3.j, 4., -5.j])

    def f(z):
      return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0))

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          f,
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(jnp.zeros_like(z0))
    self.assertAllClose(jax_res, z0)

  def test_complex_rosenbrock(self):
    complex_dim = 5
    f_re = rosenbrock(jnp)
    init_re = jnp.zeros((2 * complex_dim,), dtype=complex)
    expect_re = jnp.ones((2 * complex_dim,), dtype=complex)

    def f(z):
      x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
      return f_re(x_re)

    init = init_re[:complex_dim] + 1.j * init_re[complex_dim:]
    expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:]

    @jit
    def min_op(z0):
      result = jax.scipy.optimize.minimize(
          f,
          z0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(init)
    self.assertAllClose(jax_res, expect, atol=2e-5)
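

# Illustrative sketch (an addition, not an upstream test): the complex tests
# above minimize real-valued objectives of complex inputs. For the sphere
# objective f(z) = Re(conj(z - z0) . (z - z0)), the value equals the squared
# Euclidean norm of the stacked real and imaginary parts of z - z0, the same
# real/imaginary stacking that test_complex_rosenbrock uses. This hypothetical
# helper (NumPy only) computes both forms so they can be compared.
def _complex_sphere_two_ways(z, z0):
  import numpy as np

  d = np.asarray(z) - np.asarray(z0)
  # Complex form, as in test_minimize_complex_sphere.
  complex_form = np.real(np.dot(np.conj(d), d))
  # Equivalent real form over the concatenated real and imaginary parts.
  real_form = np.sum(np.concatenate([np.real(d), np.imag(d)]) ** 2)
  return complex_form, real_form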


if __name__ == "__main__":
  absltest.main()