rocm_jax/tests/scipy_optimize_test.py
Peter Hawkins 291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00

244 lines
6.9 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import numpy as np
import scipy
import scipy.optimize
import jax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import jit
import jax.scipy.optimize
# Presumably lets absl-parsed command-line flags set JAX config options
# (e.g. jax_enable_x64, which test_zakharov's skip decorator inspects).
jax.config.parse_flags_with_absl()
def rosenbrock(np):
  """Build a Rosenbrock-style objective using the array module ``np``.

  ``np`` is an array namespace (numpy or jax.numpy) supplying ``sum`` and
  ``diff``, so the same objective can run under either library.

  Note: the coupling term is ``100 * (x[i+1] - x[i])**2`` (computed with
  ``np.diff``), not the textbook ``100 * (x[i+1] - x[i]**2)**2``. Every term
  is a square, and all of them vanish at ``x = ones``, which is the minimum.
  """
  def objective(x):
    coupling = 100. * np.diff(x) ** 2
    residual = (1. - x[:-1]) ** 2
    return np.sum(coupling + residual)
  return objective
def himmelblau(np):
  """Build Himmelblau's two-variable objective.

  ``np`` is accepted so every objective factory in this file has the same
  signature; this objective uses only arithmetic operators and never
  touches it.
  """
  def objective(point):
    x, y = point
    first = x ** 2 + y - 11.
    second = x + y ** 2 - 7.
    return first ** 2 + second ** 2
  return objective
def matyas(np):
  """Build the Matyas objective ``0.26*(x^2 + y^2) - 0.48*x*y``.

  ``np`` is part of the shared factory signature but is unused here. The
  objective is zero at the origin.
  """
  def objective(point):
    x, y = point
    radial = 0.26 * (x ** 2 + y ** 2)
    cross = 0.48 * x * y
    return radial - cross
  return objective
def eggholder(np):
  """Build the Eggholder objective using ``np.sin``/``np.sqrt``/``np.abs``.

  f(x, y) = -(y + 47) * sin(sqrt(|x/2 + y + 47|))
            - x * sin(sqrt(|x - (y + 47)|))
  """
  def objective(point):
    x, y = point
    shifted = y + 47
    first = shifted * np.sin(np.sqrt(np.abs(x / 2. + y + 47.)))
    second = x * np.sin(np.sqrt(np.abs(x - (y + 47.))))
    # -(a * b) equals (-a) * b exactly in IEEE arithmetic.
    return -first - second
  return objective
def zakharovFromIndices(x, ii):
  """Evaluate the Zakharov function given precomputed 1-based indices.

  Computes ``sum(x**2) + s**2 + s**4`` where ``s = sum(0.5 * ii * x)``.
  ``x`` and ``ii`` must be arrays exposing ``.sum()``, and ``ii`` is
  expected to hold 1, 2, ..., len(x). Every term is nonnegative, so the
  value is zero at ``x = 0``.
  """
  quadratic = (x ** 2).sum()
  weighted = (0.5 * ii * x).sum()
  return quadratic + weighted ** 2 + weighted ** 4
class TestBFGS(jtu.JaxTestCase):
  """Tests for ``jax.scipy.optimize.minimize(..., method='BFGS')``."""

  # Each sampled case pairs an objective factory (called with an array
  # module) with an initial point. The Rosenbrock/Himmelblau starts are
  # float32; the Matyas/Eggholder starts are default-dtype NumPy arrays.
  @jtu.sample_product(
    maxiter=[None],
    func_and_init=[(rosenbrock, np.zeros(2, dtype='float32')),
                   (himmelblau, np.ones(2, dtype='float32')),
                   (matyas, np.ones(2) * 6.),
                   (eggholder, np.ones(2) * 100.)],
  )
  def test_minimize(self, maxiter, func_and_init):
    """Jitted JAX BFGS should land near scipy's BFGS solution.

    Args:
      maxiter: forwarded as ``options['maxiter']``; only ``None`` is sampled.
      func_and_init: ``(objective_factory, x0)`` pair from the product above.
    """
    # Note, cannot compare step for step with scipy BFGS because our line search is _slightly_ different.
    func, x0 = func_and_init

    # The whole minimization runs under jit; only the solution vector
    # ``result.x`` is returned.
    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='BFGS',
          options=dict(maxiter=maxiter, gtol=1e-6),
      )
      return result.x

    jax_res = min_op(x0)
    # Newer scipy versions perform poorly in float32. See
    # https://github.com/scipy/scipy/issues/19024.
    x0_f64 = x0.astype('float64')
    scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='BFGS').x
    # scipy is fed a float64 copy while JAX receives x0 unchanged (float32
    # for some cases), so dtypes are deliberately not compared.
    self.assertAllClose(scipy_res, jax_res, atol=2e-4, rtol=2e-4,
                        check_dtypes=False)

  def test_fixes4594(self):
    """BFGS on a badly scaled quadratic should still reach the origin.

    Named after issue 4594 (presumably the JAX GitHub issue of that number).
    """
    n = 2
    # The 1e4 scaling gives the objective very large curvature.
    A = jnp.eye(n) * 1e4

    # mean((A @ x)**2) is a scaled sum of squares, minimized at x = 0.
    def f(x):
      return jnp.mean((A @ x) ** 2)

    results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
    self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

  # The decorator is intended to skip this test when jax_enable_x64 is False.
  @jtu.skip_on_flag('jax_enable_x64', False)
  def test_zakharov(self):
    """BFGS should drive the 6-D Zakharov function (minimum 0) below 1e-6."""
    def zakharov_fn(x):
      # 1-based indices in x's dtype, as the Zakharov weights require.
      ii = jnp.arange(1, len(x) + 1, step=1, dtype=x.dtype)
      answer = zakharovFromIndices(x=x, ii=ii)
      return answer

    # Deliberately distant start, including one coordinate at 1e4.
    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
    eval_func = jax.jit(zakharov_fn)
    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
    self.assertLess(jax_res.fun, 1e-6)

  # Silences 'divide by zero' RuntimeWarnings, presumably raised by scipy's
  # solver on this starting point.
  @jtu.ignore_warning(category=RuntimeWarning, message='divide by zero')
  def test_minimize_bad_initial_values(self):
    """JAX and scipy BFGS should agree even from a poor start point."""
    # This test runs deliberately "bad" initial values to test that handling
    # of failed line search, etc. is the same across implementations
    initial_value = jnp.array([92, 0.001])
    opt_fn = himmelblau(jnp)
    jax_res = jax.scipy.optimize.minimize(
        fun=opt_fn,
        x0=initial_value,
        method='BFGS',
    ).x
    # scipy is handed JAX's gradient so both solvers see identical derivatives.
    scipy_res = scipy.optimize.minimize(
        fun=opt_fn,
        jac=jax.grad(opt_fn),
        method='BFGS',
        x0=initial_value
    ).x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_args_must_be_tuple(self):
    """A non-tuple ``args`` must be rejected with a TypeError."""
    A = jnp.eye(2) * 1e4

    def f(x):
      return jnp.mean((A @ x) ** 2)

    with self.assertRaisesRegex(TypeError, "args .* must be a tuple"):
      jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
class TestLBFGS(jtu.JaxTestCase):
  """Tests for the experimental L-BFGS method of ``jax.scipy.optimize``.

  The method string 'l-bfgs-experimental-do-not-rely-on-this' is used
  verbatim throughout.
  """

  # Each sampled case pairs an objective factory with a float64 NumPy start.
  @jtu.sample_product(
    maxiter=[None],
    func_and_init=[(rosenbrock, np.zeros(2)),
                   (himmelblau, np.zeros(2)),
                   (matyas, np.ones(2) * 6.),
                   (eggholder, np.ones(2) * 100.)],
  )
  def test_minimize(self, maxiter, func_and_init):
    """Jitted JAX L-BFGS should match scipy's L-BFGS-B.

    Exceptions: Matyas is checked against its true minimum (the origin), and
    Eggholder is only loosely compared to scipy.
    """
    func, x0 = func_and_init

    # The whole minimization runs under jit; only ``result.x`` is returned.
    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(maxiter=maxiter, gtol=1e-7),
      )
      return result.x

    jax_res = min_op(x0)
    # Newer scipy versions perform poorly in float32. See
    # https://github.com/scipy/scipy/issues/19024.
    x0_f64 = x0.astype('float64')
    # Note that without bounds, L-BFGS-B is just L-BFGS
    # The filter hides a 'tostring is deprecated' DeprecationWarning,
    # presumably emitted from within scipy's L-BFGS-B wrapper.
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='L-BFGS-B').x
    # NOTE(review): scipy_res is computed for every case but is unused on the
    # Matyas branch below.
    if func.__name__ == 'matyas':
      # scipy performs badly for Matyas, compare to true minimum instead
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
      return
    if func.__name__ == 'eggholder':
      # L-BFGS performs poorly for the eggholder function.
      # Neither scipy nor jax find the true minimum, so we can only loosely
      # (with high atol) compare the false results
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
      return
    self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)

  def test_minimize_complex_sphere(self):
    """L-BFGS over complex inputs should recover the sphere's center z0."""
    z0 = jnp.array([1., 2. - 3.j, 4., -5.j])

    # Real-valued squared distance sum(|z - z0|^2). jnp.dot does not
    # conjugate, hence the explicit conj. Minimized at z = z0.
    def f(z):
      return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0))

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          f,
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(jnp.zeros_like(z0))
    self.assertAllClose(jax_res, z0)

  def test_complex_rosenbrock(self):
    """L-BFGS on a complex vector whose real and imaginary parts are fed to
    the real-valued ``rosenbrock`` objective."""
    complex_dim = 5
    f_re = rosenbrock(jnp)
    # Real-space start (zeros) and expected optimum (ones), each of length
    # 2 * complex_dim; they are split into real/imaginary halves below.
    init_re = jnp.zeros((2 * complex_dim,), dtype=complex)
    expect_re = jnp.ones((2 * complex_dim,), dtype=complex)

    # Concatenate Re(z) and Im(z) into one real vector before evaluating.
    def f(z):
      x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
      return f_re(x_re)

    # First half becomes the real part, second half the imaginary part.
    init = init_re[:complex_dim] + 1.j * init_re[complex_dim:]
    expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:]

    @jit
    def min_op(z0):
      result = jax.scipy.optimize.minimize(
          f,
          z0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(init)
    self.assertAllClose(jax_res, expect, atol=2e-5)
# Entry point when this file is executed directly: hand control to
# absltest's test runner.
if __name__ == "__main__":
  absltest.main()