Make jax.scipy.optimize test compatible with upstream scipy

This commit is contained in:
Jake VanderPlas 2023-11-06 14:10:24 -08:00
parent 7e372944f9
commit 17a26235e6

View File

@ -90,7 +90,10 @@ class TestBFGS(jtu.JaxTestCase):
return result.x
jax_res = min_op(x0)
scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
# Newer scipy versions perform poorly in float32. See
# https://github.com/scipy/scipy/issues/19024.
x0_f64 = x0.astype('float64')
scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='BFGS').x
self.assertAllClose(scipy_res, jax_res, atol=2e-4, rtol=2e-4,
check_dtypes=False)
@ -166,10 +169,13 @@ class TestLBFGS(jtu.JaxTestCase):
jax_res = min_op(x0)
# Newer scipy versions perform poorly in float32. See
# https://github.com/scipy/scipy/issues/19024.
x0_f64 = x0.astype('float64')
# Note that without bounds, L-BFGS-B is just L-BFGS
with jtu.ignore_warning(category=DeprecationWarning,
message=".*tostring.*is deprecated.*"):
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='L-BFGS-B').x
if func.__name__ == 'matyas':
# scipy performs badly for Matyas, compare to true minimum instead