[x64] more type safety in scipy.optimize.line_search

This commit is contained in:
Jake VanderPlas 2022-12-01 14:04:39 -08:00
parent b3b7eb68f1
commit 37acc6e426
2 changed files with 8 additions and 4 deletions

View File

@ -24,13 +24,15 @@ _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
def _cubicmin(a, fa, fpa, b, fb, c, fc):
dtype = jnp.result_type(a, fa, fpa, b, fb, c, fc)
C = fpa
db = b - a
dc = c - a
denom = (db * dc) ** 2 * (db - dc)
d1 = jnp.array([[dc ** 2, -db ** 2],
[-dc ** 3, db ** 3]])
A, B = _dot(d1, jnp.array([fb - fa - C * db, fc - fa - C * dc])) / denom
[-dc ** 3, db ** 3]], dtype=dtype)
d2 = jnp.array([fb - fa - C * db, fc - fa - C * dc], dtype=dtype)
A, B = _dot(d1, d2) / denom
radical = B * B - 3. * A * C
xmin = a + (-B + jnp.sqrt(radical)) / (3. * A)

View File

@ -1,6 +1,7 @@
from absl.testing import absltest
import scipy.optimize
import jax
from jax import grad
from jax.config import config
import jax.numpy as jnp
@ -150,10 +151,11 @@ class TestLineSearch(jtu.JaxTestCase):
# assert not line_search(jax.value_and_grad(f), np.ones(2), np.array([-0.5, -0.25])).failed
xk = jnp.ones(2)
pk = jnp.array([-0.5, -0.25])
pk = jnp.array([-0.5, -0.25], dtype=xk.dtype)
res = line_search(f, xk, pk, maxiter=100)
scipy_res = scipy.optimize.line_search(f, grad(f), xk, pk)
with jax.numpy_dtype_promotion('standard'):
scipy_res = scipy.optimize.line_search(f, grad(f), xk, pk)
self.assertAllClose(scipy_res[0], res.a_k, atol=1e-5, check_dtypes=False)
self.assertAllClose(scipy_res[3], res.f_k, atol=1e-5, check_dtypes=False)