Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add rmsprop_momentum optimizer (same as TF RMSProp)
The TensorFlow RMSProp optimizer supports an additional momentum parameter that applies momentum to the RMSProp update (the recurrence is sketched below). Supporting momentum requires keeping extra state around, which may not be desirable when using the standard RMSProp optimizer, so I've added a separate rmsprop_momentum optimizer for this case. RMSProp with momentum can be necessary to reproduce some research papers.
Parent: c41677fac7
Commit: 628f87d365
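
For reference, the update added by this commit follows the TF-style RMSProp-with-momentum recurrence described above. Below is a minimal standalone NumPy sketch of one step; the function name, default values, and toy objective are illustrative only, not part of the commit:

import numpy as np

def rmsprop_momentum_step(x, avg_sq_grad, mom, g,
                          step_size=1e-3, gamma=0.9, eps=1e-8, momentum=0.9):
  # Exponential moving average of squared gradients, as in plain RMSProp.
  avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
  # Momentum buffer accumulates the scaled RMSProp step; this buffer is the
  # extra state that the plain rmsprop optimizer does not keep.
  mom = momentum * mom + step_size * g / np.sqrt(avg_sq_grad + eps)
  return x - mom, avg_sq_grad, mom

# One step on a toy quadratic f(x) = 0.5 * sum(x**2), whose gradient is x.
x = np.array([1.0, -2.0])
avg_sq_grad, mom = np.zeros_like(x), np.zeros_like(x)
x, avg_sq_grad, mom = rmsprop_momentum_step(x, avg_sq_grad, mom, g=x)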
@@ -269,6 +269,8 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.

@@ -287,6 +289,41 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
    return x
  return init, update, get_params


@optimizer
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
  """Construct optimizer triple for RMSProp with momentum.

  This optimizer is separate from the rmsprop optimizer because it needs to
  keep track of additional parameters.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.
    momentum: Momentum parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = optimizers.make_schedule(step_size)
  def init(x0):
    avg_sq_grad = np.zeros_like(x0)
    mom = np.zeros_like(x0)
    return x0, avg_sq_grad, mom
  def update(i, g, state):
    x, avg_sq_grad, mom = state
    avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
    mom = momentum * mom + step_size(i) * g / np.sqrt(avg_sq_grad + eps)
    x = x - mom
    return x, avg_sq_grad, mom
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for Adam.