rocm_jax/tests/optimizers_test.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import lax
from jax.example_libraries import optimizers
jax.config.parse_flags_with_absl()
class OptimizerTests(jtu.JaxTestCase):
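
  # Shared helpers: _CheckOptimizer combines the API check (_CheckFuns) with a
  # convergence run (_CheckRun) for a given optimizer and loss.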
def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs):
self._CheckFuns(optimizer, loss, x0, *args)
self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs)
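
  # API check: get_params(init_fun(x0)) must round-trip x0, and one call to
  # update_fun must preserve the pytree structure of the optimizer state.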
def _CheckFuns(self, optimizer, loss, x0, *args):
init_fun, update_fun, get_params = optimizer(*args)
opt_state = init_fun(x0)
self.assertAllClose(x0, get_params(opt_state))
opt_state2 = update_fun(0, grad(loss)(x0), opt_state) # doesn't crash
    self.assertEqual(jax.tree.structure(opt_state),
                     jax.tree.structure(opt_state2))
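
  # Convergence check: run num_steps update steps and require the final loss
  # to drop below 1e-2, first eagerly and then with the update function jitted.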
@jtu.skip_on_devices('gpu')
def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):
init_fun, update_fun, get_params = optimizer(*args)
opt_state = init_fun(x0)
for i in range(num_steps):
x = get_params(opt_state)
g = grad(loss)(x)
opt_state = update_fun(i, g, opt_state)
xstar = get_params(opt_state)
self.assertLess(loss(xstar), 1e-2)
update_fun_jitted = jit(update_fun)
opt_state = init_fun(x0)
for i in range(num_steps):
x = get_params(opt_state)
g = grad(loss)(x)
opt_state = update_fun_jitted(i, g, opt_state)
xstar = get_params(opt_state)
self.assertLess(loss(xstar), 1e-2)
def testSgdScalar(self):
def loss(x): return x**2
x0 = 1.
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)
def testSgdVector(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)
def testSgdNestedTuple(self):
def loss(xyz):
x, (y, z) = xyz
return sum(jnp.dot(a, a) for a in [x, y, z])
x0 = (jnp.ones(2), (jnp.ones(2), jnp.ones(2)))
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)
def testMomentumVector(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
num_iters = 100
step_size = 0.1
mass = 0.
self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)
def testMomentumDict(self):
def loss(dct): return jnp.dot(dct['x'], dct['x'])
x0 = {'x': jnp.ones(2)}
num_iters = 100
step_size = 0.1
mass = 0.
self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)
def testRmspropVector(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_size)
def testAdamVector(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_size)
def testSgdClosure(self):
def loss(y, x): return y**2 * x**2
x0 = 1.
y = 1.
num_iters = 20
step_size = 0.1
partial_loss = functools.partial(loss, y)
self._CheckRun(optimizers.sgd, partial_loss, x0, num_iters, step_size)
def testAdagrad(self):
def loss(xs):
x1, x2 = xs
return jnp.sum(x1**2) + jnp.sum(x2**2)
num_iters = 100
step_size = 0.1
x0 = (jnp.ones(2), jnp.ones((2, 2)))
self._CheckOptimizer(optimizers.adagrad, loss, x0, num_iters, step_size)
def testSM3Scalar(self):
def loss(x): return x**2
x0 = jnp.array(1.)
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)
def testSM3Vector(self):
def loss(xs):
x1, x2 = xs
return jnp.sum(x1 ** 2) + jnp.sum(x2 ** 2)
num_iters = 100
step_size = 0.1
x0 = (jnp.ones(2), jnp.ones((2, 2)))
self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)
def testAdaMaxVector(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
num_iters = 100
step_size = 0.1
self._CheckOptimizer(optimizers.adamax, loss, x0, num_iters, step_size)
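
  # The schedule tests below pass a step-size schedule (a callable from step
  # index to step size) in place of a constant step size; they only exercise
  # the optimizer API via _CheckFuns rather than running a full optimization.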
def testSgdVectorExponentialDecaySchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
def testSgdVectorInverseTimeDecaySchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
def testAdamVectorInverseTimeDecaySchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
self._CheckFuns(optimizers.adam, loss, x0, step_schedule)
def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
mass = 0.9
self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)
def testRmspropmomentumVectorPolynomialDecaySchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_schedule = optimizers.polynomial_decay(1.0, 50, 0.1)
self._CheckFuns(optimizers.rmsprop_momentum, loss, x0, step_schedule)
def testRmspropVectorPiecewiseConstantSchedule(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)
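
  # Rebuilding the optimizer inside a jitted function turns the step size into
  # a traced value; this checks that update_fun still works in that case.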
def testTracedStepSize(self):
def loss(x): return jnp.dot(x, x)
x0 = jnp.ones(2)
step_size = 0.1
init_fun, _, _ = optimizers.sgd(step_size)
opt_state = init_fun(x0)
@jit
def update(opt_state, step_size):
_, update_fun, get_params = optimizers.sgd(step_size)
x = get_params(opt_state)
g = grad(loss)(x)
return update_fun(0, g, opt_state)
update(opt_state, 0.9) # doesn't crash
# TODO(mattjj): re-enable
# def testDeviceTupleState(self):
# init_fun, update_fun, _ = optimizers.sgd(0.1)
# opt_state = init_fun(jnp.zeros(3))
# self.assertIsInstance(opt_state, optimizers.OptimizerState)
# self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
# opt_state = jit(update_fun)(0, jnp.zeros(3), opt_state)
# self.assertIsInstance(opt_state, optimizers.OptimizerState)
# self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)
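
  # The optimizers.optimizer decorator flattens optimizer state over pytrees
  # and should raise a TypeError when a custom update_fun returns state whose
  # tree structure differs from the one produced by init_fun.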
def testUpdateFunStructureMismatchErrorMessage(self):
@optimizers.optimizer
def opt_maker():
def init_fun(x0):
return {'x': x0}
def update_fun(i, g, opt_state):
x = opt_state['x']
return {'x': x - 0.1 * g, 'v': g} # bug!
def get_params(opt_state):
return opt_state['x']
return init_fun, update_fun, get_params
init_fun, update_fun, get_params = opt_maker()
opt_state = init_fun(jnp.zeros(3))
self.assertRaises(TypeError, lambda: update_fun(opt_state))
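
  # l2_norm computes the global L2 norm over all leaves of a pytree;
  # clip_grads rescales gradients only when that norm exceeds the given bound.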
def testUtilityNorm(self):
x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
norm = optimizers.l2_norm(x0)
expected = np.sqrt(np.sum(np.ones(2+3+4)**2))
self.assertAllClose(norm, expected, check_dtypes=False)
def testUtilityClipGrads(self):
g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
norm = optimizers.l2_norm(g)
ans = optimizers.clip_grads(g, 1.1 * norm)
expected = g
self.assertAllClose(ans, expected, check_dtypes=False)
ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
expected = 0.9 * norm
self.assertAllClose(ans, expected, check_dtypes=False)
def testIssue758(self):
# code from https://github.com/jax-ml/jax/issues/758
# this is more of a scan + jacfwd/jacrev test, but it lives here to use the
# optimizers.py code
def harmonic_bond(conf, params):
return jnp.sum(conf * params)
opt_init, opt_update, get_params = optimizers.sgd(5e-2)
x0 = np.array([0.5], dtype=np.float64)
def minimize_structure(test_params):
energy_fn = functools.partial(harmonic_bond, params=test_params)
grad_fn = grad(energy_fn, argnums=(0,))
opt_state = opt_init(x0)
def apply_carry(carry, _):
i, x = carry
g = grad_fn(get_params(x))[0]
new_state = opt_update(i, g, x)
new_carry = (i+1, new_state)
return new_carry, _
carry_final, _ = lax.scan(apply_carry, (0, opt_state), jnp.zeros((75, 0)))
trip, opt_final = carry_final
assert trip == 75
return opt_final
initial_params = jnp.array(0.5)
minimize_structure(initial_params)
def loss(test_params):
opt_final = minimize_structure(test_params)
return 1.0 - get_params(opt_final)[0]
loss_opt_init, loss_opt_update, loss_get_params = optimizers.sgd(5e-2)
J1 = jacrev(loss, argnums=(0,))(initial_params)
J2 = jacfwd(loss, argnums=(0,))(initial_params)
self.assertAllClose(J1, J2, rtol=1e-6)
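
  # unpack_optimizer_state turns an OptimizerState into a serializable pytree
  # and pack_optimizer_state is its inverse; the round trip should be exact.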
def testUnpackPackRoundTrip(self):
opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
params = [{'w': self.rng().randn(1, 2), 'bias': self.rng().randn(2)}]
expected = opt_init(params)
    ans = optimizers.pack_optimizer_state(
        optimizers.unpack_optimizer_state(expected))
self.assertEqual(ans, expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())