# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl.testing import absltest
import numpy as np

import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import lax
from jax.example_libraries import optimizers

jax.config.parse_flags_with_absl()


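# Each optimizer in jax.example_libraries.optimizers is an
# (init_fun, update_fun, get_params) triple: init_fun(params) builds the
# optimizer state, update_fun(step, grads, opt_state) returns an updated
# state, and get_params(opt_state) recovers the parameter pytree. The
# helpers below check that contract and run a short optimization loop.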
class OptimizerTests(jtu.JaxTestCase):

  def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs):
    self._CheckFuns(optimizer, loss, x0, *args)
    self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs)

  def _CheckFuns(self, optimizer, loss, x0, *args):
    init_fun, update_fun, get_params = optimizer(*args)
    opt_state = init_fun(x0)
    self.assertAllClose(x0, get_params(opt_state))
    opt_state2 = update_fun(0, grad(loss)(x0), opt_state)  # doesn't crash
    self.assertEqual(jax.tree.structure(opt_state),
                     jax.tree.structure(opt_state2))

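  # _CheckRun takes num_steps optimizer steps, both with and without jitting
  # update_fun, and asserts the loss is driven below 1e-2.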
  @jtu.skip_on_devices('gpu')
  def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):
    init_fun, update_fun, get_params = optimizer(*args)

    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params(opt_state)
      g = grad(loss)(x)
      opt_state = update_fun(i, g, opt_state)
    xstar = get_params(opt_state)
    self.assertLess(loss(xstar), 1e-2)

    update_fun_jitted = jit(update_fun)
    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params(opt_state)
      g = grad(loss)(x)
      opt_state = update_fun_jitted(i, g, opt_state)
    xstar = get_params(opt_state)
    self.assertLess(loss(xstar), 1e-2)

  def testSgdScalar(self):
    def loss(x): return x**2
    x0 = 1.
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testSgdVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testSgdNestedTuple(self):
    def loss(xyz):
      x, (y, z) = xyz
      return sum(jnp.dot(a, a) for a in [x, y, z])
    x0 = (jnp.ones(2), (jnp.ones(2), jnp.ones(2)))
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testMomentumVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testMomentumDict(self):
    def loss(dct): return jnp.dot(dct['x'], dct['x'])
    x0 = {'x': jnp.ones(2)}
    num_iters = 100
    step_size = 0.1
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testRmspropVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_size)

  def testAdamVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_size)

  def testSgdClosure(self):
    def loss(y, x): return y**2 * x**2
    x0 = 1.
    y = 1.
    num_iters = 20
    step_size = 0.1
    partial_loss = functools.partial(loss, y)
    self._CheckRun(optimizers.sgd, partial_loss, x0, num_iters, step_size)

  def testAdagrad(self):

    def loss(xs):
      x1, x2 = xs
      return jnp.sum(x1**2) + jnp.sum(x2**2)

    num_iters = 100
    step_size = 0.1
    x0 = (jnp.ones(2), jnp.ones((2, 2)))
    self._CheckOptimizer(optimizers.adagrad, loss, x0, num_iters, step_size)

  def testSM3Scalar(self):
    def loss(x): return x**2
    x0 = jnp.array(1.)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)

  def testSM3Vector(self):
    def loss(xs):
      x1, x2 = xs
      return jnp.sum(x1 ** 2) + jnp.sum(x2 ** 2)

    num_iters = 100
    step_size = 0.1
    x0 = (jnp.ones(2), jnp.ones((2, 2)))
    self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)

  def testAdaMaxVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.adamax, loss, x0, num_iters, step_size)

  def testSgdVectorExponentialDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testSgdVectorInverseTimeDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testAdamVectorInverseTimeDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.adam, loss, x0, step_schedule)

  def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
    mass = 0.9
    self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)

  def testRmspropmomentumVectorPolynomialDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.polynomial_decay(1.0, 50, 0.1)
    self._CheckFuns(optimizers.rmsprop_momentum, loss, x0, step_schedule)

  def testRmspropVectorPiecewiseConstantSchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
    self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)

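  # The step size may itself be a traced value: rebuilding the sgd triple
  # inside a jitted function with a dynamically passed step size should work.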
  def testTracedStepSize(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_size = 0.1

    init_fun, _, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
      _, update_fun, get_params = optimizers.sgd(step_size)
      x = get_params(opt_state)
      g = grad(loss)(x)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash

  # TODO(mattjj): re-enable
  # def testDeviceTupleState(self):
  #   init_fun, update_fun, _ = optimizers.sgd(0.1)
  #   opt_state = init_fun(jnp.zeros(3))
  #   self.assertIsInstance(opt_state, optimizers.OptimizerState)
  #   self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
  #   opt_state = jit(update_fun)(0, jnp.zeros(3), opt_state)
  #   self.assertIsInstance(opt_state, optimizers.OptimizerState)
  #   self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)

  def testUpdateFunStructureMismatchErrorMessage(self):
    @optimizers.optimizer
    def opt_maker():
      def init_fun(x0):
        return {'x': x0}
      def update_fun(i, g, opt_state):
        x = opt_state['x']
        return {'x': x - 0.1 * g, 'v': g}  # bug!
      def get_params(opt_state):
        return opt_state['x']
      return init_fun, update_fun, get_params

    init_fun, update_fun, get_params = opt_maker()
    opt_state = init_fun(jnp.zeros(3))
    self.assertRaises(TypeError, lambda: update_fun(opt_state))

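  # optimizers.l2_norm computes one global l2 norm over every leaf of a
  # pytree, and optimizers.clip_grads rescales a gradient pytree so that its
  # global norm does not exceed the given maximum.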
  def testUtilityNorm(self):
    x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
    norm = optimizers.l2_norm(x0)
    expected = np.sqrt(np.sum(np.ones(2+3+4)**2))
    self.assertAllClose(norm, expected, check_dtypes=False)

  def testUtilityClipGrads(self):
    g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
    norm = optimizers.l2_norm(g)

    ans = optimizers.clip_grads(g, 1.1 * norm)
    expected = g
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
    expected = 0.9 * norm
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue758(self):
    # code from https://github.com/jax-ml/jax/issues/758
    # this is more of a scan + jacfwd/jacrev test, but it lives here to use the
    # optimizers.py code

    def harmonic_bond(conf, params):
      return jnp.sum(conf * params)

    opt_init, opt_update, get_params = optimizers.sgd(5e-2)

    x0 = np.array([0.5], dtype=np.float64)

    def minimize_structure(test_params):
      energy_fn = functools.partial(harmonic_bond, params=test_params)
      grad_fn = grad(energy_fn, argnums=(0,))
      opt_state = opt_init(x0)

      def apply_carry(carry, _):
        i, x = carry
        g = grad_fn(get_params(x))[0]
        new_state = opt_update(i, g, x)
        new_carry = (i+1, new_state)
        return new_carry, _

      carry_final, _ = lax.scan(apply_carry, (0, opt_state), jnp.zeros((75, 0)))
      trip, opt_final = carry_final
      assert trip == 75
      return opt_final

    initial_params = jnp.array(0.5)
    minimize_structure(initial_params)

    def loss(test_params):
      opt_final = minimize_structure(test_params)
      return 1.0 - get_params(opt_final)[0]

    loss_opt_init, loss_opt_update, loss_get_params = optimizers.sgd(5e-2)

    J1 = jacrev(loss, argnums=(0,))(initial_params)
    J2 = jacfwd(loss, argnums=(0,))(initial_params)
    self.assertAllClose(J1, J2, rtol=1e-6)

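  # unpack_optimizer_state converts the opaque OptimizerState into a pytree
  # representation and pack_optimizer_state converts it back; the round trip
  # should reproduce the original state exactly.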
  def testUnpackPackRoundTrip(self):
    opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
    params = [{'w': self.rng().randn(1, 2), 'bias': self.rng().randn(2)}]
    expected = opt_init(params)
    ans = optimizers.pack_optimizer_state(
        optimizers.unpack_optimizer_state(expected))
    self.assertEqual(ans, expected)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())