mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Make jax.scipy.optimize test compatible with upstream scipy
This commit is contained in:
parent
7e372944f9
commit
17a26235e6
@ -90,7 +90,10 @@ class TestBFGS(jtu.JaxTestCase):
|
||||
return result.x
|
||||
|
||||
jax_res = min_op(x0)
|
||||
scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
|
||||
# Newer scipy versions perform poorly in float32. See
|
||||
# https://github.com/scipy/scipy/issues/19024.
|
||||
x0_f64 = x0.astype('float64')
|
||||
scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='BFGS').x
|
||||
self.assertAllClose(scipy_res, jax_res, atol=2e-4, rtol=2e-4,
|
||||
check_dtypes=False)
|
||||
|
||||
@ -166,10 +169,13 @@ class TestLBFGS(jtu.JaxTestCase):
|
||||
|
||||
jax_res = min_op(x0)
|
||||
|
||||
# Newer scipy versions perform poorly in float32. See
|
||||
# https://github.com/scipy/scipy/issues/19024.
|
||||
x0_f64 = x0.astype('float64')
|
||||
# Note that without bounds, L-BFGS-B is just L-BFGS
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message=".*tostring.*is deprecated.*"):
|
||||
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
|
||||
scipy_res = scipy.optimize.minimize(func(np), x0_f64, method='L-BFGS-B').x
|
||||
|
||||
if func.__name__ == 'matyas':
|
||||
# scipy performs badly for Matyas, compare to true minimum instead
|
||||
|
Loading…
x
Reference in New Issue
Block a user