mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[x64] more type safety in scipy.optimize.line_search
This commit is contained in:
parent
b3b7eb68f1
commit
37acc6e426
@ -24,13 +24,15 @@ _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
def _cubicmin(a, fa, fpa, b, fb, c, fc):
|
||||
dtype = jnp.result_type(a, fa, fpa, b, fb, c, fc)
|
||||
C = fpa
|
||||
db = b - a
|
||||
dc = c - a
|
||||
denom = (db * dc) ** 2 * (db - dc)
|
||||
d1 = jnp.array([[dc ** 2, -db ** 2],
|
||||
[-dc ** 3, db ** 3]])
|
||||
A, B = _dot(d1, jnp.array([fb - fa - C * db, fc - fa - C * dc])) / denom
|
||||
[-dc ** 3, db ** 3]], dtype=dtype)
|
||||
d2 = jnp.array([fb - fa - C * db, fc - fa - C * dc], dtype=dtype)
|
||||
A, B = _dot(d1, d2) / denom
|
||||
|
||||
radical = B * B - 3. * A * C
|
||||
xmin = a + (-B + jnp.sqrt(radical)) / (3. * A)
|
||||
|
6
tests/third_party/scipy/line_search_test.py
vendored
6
tests/third_party/scipy/line_search_test.py
vendored
@ -1,6 +1,7 @@
|
||||
from absl.testing import absltest
|
||||
import scipy.optimize
|
||||
|
||||
import jax
|
||||
from jax import grad
|
||||
from jax.config import config
|
||||
import jax.numpy as jnp
|
||||
@ -150,10 +151,11 @@ class TestLineSearch(jtu.JaxTestCase):
|
||||
|
||||
# assert not line_search(jax.value_and_grad(f), np.ones(2), np.array([-0.5, -0.25])).failed
|
||||
xk = jnp.ones(2)
|
||||
pk = jnp.array([-0.5, -0.25])
|
||||
pk = jnp.array([-0.5, -0.25], dtype=xk.dtype)
|
||||
res = line_search(f, xk, pk, maxiter=100)
|
||||
|
||||
scipy_res = scipy.optimize.line_search(f, grad(f), xk, pk)
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
scipy_res = scipy.optimize.line_search(f, grad(f), xk, pk)
|
||||
|
||||
self.assertAllClose(scipy_res[0], res.a_k, atol=1e-5, check_dtypes=False)
|
||||
self.assertAllClose(scipy_res[3], res.f_k, atol=1e-5, check_dtypes=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user