mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[x64] make scipy_optimize_test compatible with strict dtype promotion
This commit is contained in:
parent
d7ec244c0c
commit
3a8f478b0a
@ -155,13 +155,13 @@ def _minimize_lbfgs(
|
||||
)
|
||||
|
||||
# evaluate at next iterate
|
||||
s_k = ls_results.a_k * p_k
|
||||
s_k = ls_results.a_k.astype(p_k.dtype) * p_k
|
||||
x_kp1 = state.x_k + s_k
|
||||
f_kp1 = ls_results.f_k
|
||||
g_kp1 = ls_results.g_k
|
||||
y_k = g_kp1 - state.g_k
|
||||
rho_k_inv = jnp.real(_dot(y_k, s_k))
|
||||
rho_k = jnp.reciprocal(rho_k_inv)
|
||||
rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
|
||||
gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))
|
||||
|
||||
# replacements for next iteration
|
||||
@ -198,6 +198,7 @@ def _minimize_lbfgs(
|
||||
|
||||
|
||||
def _two_loop_recursion(state: LBFGSResults):
|
||||
dtype = state.rho_history.dtype
|
||||
his_size = len(state.rho_history)
|
||||
curr_size = jnp.where(state.k < his_size, state.k, his_size)
|
||||
q = -jnp.conj(state.g_k)
|
||||
@ -206,7 +207,7 @@ def _two_loop_recursion(state: LBFGSResults):
|
||||
def body_fun1(j, carry):
|
||||
i = his_size - 1 - j
|
||||
_q, _a_his = carry
|
||||
a_i = state.rho_history[i] * jnp.real(_dot(jnp.conj(state.s_history[i]), _q))
|
||||
a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
|
||||
_a_his = _a_his.at[i].set(a_i)
|
||||
_q = _q - a_i * jnp.conj(state.y_history[i])
|
||||
return _q, _a_his
|
||||
@ -216,7 +217,7 @@ def _two_loop_recursion(state: LBFGSResults):
|
||||
|
||||
def body_fun2(j, _q):
|
||||
i = his_size - curr_size + j
|
||||
b_i = state.rho_history[i] * jnp.real(_dot(state.y_history[i], _q))
|
||||
b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
|
||||
_q = _q + (a_his[i] - b_i) * state.s_history[i]
|
||||
return _q
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
from typing import NamedTuple, Union
|
||||
from functools import partial
|
||||
|
||||
from jax._src.numpy.util import _promote_dtypes_inexact
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -269,13 +270,15 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
|
||||
|
||||
Returns: LineSearchResults
|
||||
"""
|
||||
xk, pk = _promote_dtypes_inexact(xk, pk)
|
||||
def restricted_func_and_grad(t):
|
||||
t = jnp.array(t, dtype=pk.dtype)
|
||||
phi, g = jax.value_and_grad(f)(xk + t * pk)
|
||||
dphi = jnp.real(_dot(g, pk))
|
||||
return phi, dphi, g
|
||||
|
||||
if old_fval is None or gfk is None:
|
||||
phi_0, dphi_0, gfk = restricted_func_and_grad(0.)
|
||||
phi_0, dphi_0, gfk = restricted_func_and_grad(0)
|
||||
else:
|
||||
phi_0 = old_fval
|
||||
dphi_0 = jnp.real(_dot(gfk, pk))
|
||||
|
@ -104,7 +104,7 @@ class TestBFGS(jtu.JaxTestCase):
|
||||
@jtu.skip_on_flag('jax_enable_x64', False)
|
||||
def test_zakharov(self):
|
||||
def zakharov_fn(x):
|
||||
ii = jnp.arange(1, len(x) + 1, step=1)
|
||||
ii = jnp.arange(1, len(x) + 1, step=1, dtype=x.dtype)
|
||||
answer = zakharovFromIndices(x=x, ii=ii)
|
||||
return answer
|
||||
|
||||
@ -208,8 +208,8 @@ class TestLBFGS(jtu.JaxTestCase):
|
||||
complex_dim = 5
|
||||
|
||||
f_re = rosenbrock(jnp)
|
||||
init_re = jnp.zeros((2 * complex_dim,))
|
||||
expect_re = jnp.ones((2 * complex_dim,))
|
||||
init_re = jnp.zeros((2 * complex_dim,), dtype=complex)
|
||||
expect_re = jnp.ones((2 * complex_dim,), dtype=complex)
|
||||
|
||||
def f(z):
|
||||
x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
|
||||
|
2
tests/third_party/scipy/line_search_test.py
vendored
2
tests/third_party/scipy/line_search_test.py
vendored
@ -130,7 +130,7 @@ class TestLineSearch(jtu.JaxTestCase):
|
||||
# |x + s| <= c2 * |x|
|
||||
f = lambda x: jnp.dot(x, x)
|
||||
fp = lambda x: 2 * x
|
||||
p = jnp.array([1, 0])
|
||||
p = jnp.array([1.0, 0.0])
|
||||
|
||||
# Smallest s satisfying strong Wolfe conditions for these arguments is 30
|
||||
x = -60 * p
|
||||
|
Loading…
x
Reference in New Issue
Block a user