[x64] make scipy_optimize_test compatible with strict dtype promotion

Jake VanderPlas 2022-06-17 12:46:28 -07:00
parent d7ec244c0c
commit 3a8f478b0a
4 changed files with 13 additions and 9 deletions

jax/_src/scipy/optimize/_lbfgs.py

@@ -155,13 +155,13 @@ def _minimize_lbfgs(
)
# evaluate at next iterate
- s_k = ls_results.a_k * p_k
+ s_k = ls_results.a_k.astype(p_k.dtype) * p_k
x_kp1 = state.x_k + s_k
f_kp1 = ls_results.f_k
g_kp1 = ls_results.g_k
y_k = g_kp1 - state.g_k
rho_k_inv = jnp.real(_dot(y_k, s_k))
- rho_k = jnp.reciprocal(rho_k_inv)
+ rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))
# replacements for next iteration
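Note: a minimal sketch (not part of the diff) of the kind of failure the explicit casts above guard against. With jax_numpy_dtype_promotion set to "strict", JAX refuses to implicitly combine arrays of different dtypes, e.g. a real step size against a complex search direction:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_numpy_dtype_promotion", "strict")

a_k = jnp.asarray(0.5, dtype=jnp.float64)      # real step size from the line search
p_k = jnp.ones(3, dtype=jnp.complex128)        # possibly complex search direction

# a_k * p_k                                    # raises a dtype-promotion error in strict mode
s_k = a_k.astype(p_k.dtype) * p_k              # explicit cast keeps the product well-typed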
@@ -198,6 +198,7 @@ def _minimize_lbfgs(
def _two_loop_recursion(state: LBFGSResults):
+ dtype = state.rho_history.dtype
his_size = len(state.rho_history)
curr_size = jnp.where(state.k < his_size, state.k, his_size)
q = -jnp.conj(state.g_k)
@@ -206,7 +207,7 @@ def _two_loop_recursion(state: LBFGSResults):
def body_fun1(j, carry):
i = his_size - 1 - j
_q, _a_his = carry
- a_i = state.rho_history[i] * jnp.real(_dot(jnp.conj(state.s_history[i]), _q))
+ a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
_a_his = _a_his.at[i].set(a_i)
_q = _q - a_i * jnp.conj(state.y_history[i])
return _q, _a_his
@@ -216,7 +217,7 @@ def _two_loop_recursion(state: LBFGSResults):
def body_fun2(j, _q):
i = his_size - curr_size + j
- b_i = state.rho_history[i] * jnp.real(_dot(state.y_history[i], _q))
+ b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
_q = _q + (a_his[i] - b_i) * state.s_history[i]
return _q
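Note: a minimal sketch (not part of the diff) of why the two-loop-recursion hunks take `.real` and then cast back to the history dtype. Taking the real part of a complex inner product drops to a float dtype, which can no longer be multiplied against a complex rho_history under strict promotion:

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

rho = jnp.asarray(2.0 + 0.0j)              # stored with the (possibly complex) history dtype
s = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])
q = jnp.array([0.5 + 0.0j, 1.0 + 0.5j])

inner = jnp.vdot(s, q).real                # .real switches to the real floating dtype
# rho * inner                              # raises a dtype-promotion error in strict mode
a_i = rho * inner.astype(rho.dtype)        # cast back, mirroring .real.astype(dtype) above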

jax/_src/scipy/optimize/line_search.py

@@ -15,6 +15,7 @@
from typing import NamedTuple, Union
from functools import partial
+ from jax._src.numpy.util import _promote_dtypes_inexact
import jax.numpy as jnp
import jax
from jax import lax
@@ -269,13 +270,15 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
Returns: LineSearchResults
"""
+ xk, pk = _promote_dtypes_inexact(xk, pk)
def restricted_func_and_grad(t):
+ t = jnp.array(t, dtype=pk.dtype)
phi, g = jax.value_and_grad(f)(xk + t * pk)
dphi = jnp.real(_dot(g, pk))
return phi, dphi, g
if old_fval is None or gfk is None:
- phi_0, dphi_0, gfk = restricted_func_and_grad(0.)
+ phi_0, dphi_0, gfk = restricted_func_and_grad(0)
else:
phi_0 = old_fval
dphi_0 = jnp.real(_dot(gfk, pk))
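Note: a minimal sketch (not part of the diff) of the effect of promoting the line-search inputs. The private helper _promote_dtypes_inexact (imported above from jax._src.numpy.util) brings xk and pk to one shared inexact dtype; the sketch approximates that with explicit casts and then builds the scalar step t with a matching dtype, as the new line inside restricted_func_and_grad does:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_numpy_dtype_promotion", "strict")

xk = jnp.array([1.0, 2.0], dtype=jnp.float64)
pk = jnp.array([0.5, -1.0], dtype=jnp.float32)

# xk + pk                                  # raises a dtype-promotion error in strict mode
pk = pk.astype(xk.dtype)                   # bring both inputs to one inexact dtype first
t = jnp.array(0, dtype=pk.dtype)           # scalar step built with a matching dtype
print((xk + t * pk).dtype)                 # float64 throughout the evaluation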

tests/scipy_optimize_test.py

@@ -104,7 +104,7 @@ class TestBFGS(jtu.JaxTestCase):
@jtu.skip_on_flag('jax_enable_x64', False)
def test_zakharov(self):
def zakharov_fn(x):
- ii = jnp.arange(1, len(x) + 1, step=1)
+ ii = jnp.arange(1, len(x) + 1, step=1, dtype=x.dtype)
answer = zakharovFromIndices(x=x, ii=ii)
return answer
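Note: a minimal sketch (not part of the diff) of why the index array now carries x's dtype. jnp.arange defaults to an integer dtype, which can no longer silently combine with a floating x under strict promotion:

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

x = jnp.linspace(0.0, 1.0, 5)              # floating test point
# ii = jnp.arange(1, len(x) + 1)           # int32: ii * x raises under strict mode
ii = jnp.arange(1, len(x) + 1, dtype=x.dtype)
print((ii * x).dtype)                      # same floating dtype as x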
@@ -208,8 +208,8 @@ class TestLBFGS(jtu.JaxTestCase):
complex_dim = 5
f_re = rosenbrock(jnp)
- init_re = jnp.zeros((2 * complex_dim,))
- expect_re = jnp.ones((2 * complex_dim,))
+ init_re = jnp.zeros((2 * complex_dim,), dtype=complex)
+ expect_re = jnp.ones((2 * complex_dim,), dtype=complex)
def f(z):
x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
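Note: a minimal sketch (not part of the diff) of why the complex test now requests complex arrays explicitly. jnp.zeros and jnp.ones default to a real floating dtype, and under strict promotion a real initial point could not be combined with complex iterates later in the optimization:

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

z_float = jnp.zeros((4,))                  # defaults to a real floating dtype
z = jnp.zeros((4,), dtype=complex)         # explicitly complex initial point
d = jnp.ones((4,), dtype=z.dtype)          # updates built with the same dtype
# z_float + d                              # would mix real and complex under strict mode
print((z + d).dtype)                       # stays complex throughout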

tests/third_party/scipy/line_search_test.py

@@ -130,7 +130,7 @@ class TestLineSearch(jtu.JaxTestCase):
# |x + s| <= c2 * |x|
f = lambda x: jnp.dot(x, x)
fp = lambda x: 2 * x
- p = jnp.array([1, 0])
+ p = jnp.array([1.0, 0.0])
# Smallest s satisfying strong Wolfe conditions for these arguments is 30
x = -60 * p
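Note: a minimal sketch (not part of the diff) of why the direction vector is now built from float literals. An integer-dtype direction would trip strict promotion as soon as it meets a floating iterate, while Python scalars such as -60 remain weakly typed and are still allowed:

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

f = lambda x: jnp.dot(x, x)
p_int = jnp.array([1, 0])                  # int32
p = jnp.array([1.0, 0.0])                  # floating
x = -60 * p                                # Python scalar stays weakly typed: still floating
# x + p_int                                # raises a dtype-promotion error in strict mode
print(f(x + p).dtype)                      # floating, as expected by the test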