revise optimizers api, fix misc bugs

* add more optimizers numerical tests
* update examples and readme with new optimizers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
This commit is contained in:
parent efe99c8a84
commit 642d2dc802
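For orientation, here is a minimal sketch of the revised optimizer API that this commit introduces, as it is used in the README and examples below. Each optimizer now returns an `(init_fun, update_fun, get_params)` triple instead of a pair. The loss, batch, and initial parameters here are placeholders, not part of the commit.

```python
# Minimal sketch of the revised (init_fun, update_fun, get_params) optimizer API.
# `loss`, `batch`, and `init_params` are placeholders for illustration only.
import jax.numpy as np
from jax import jit, grad
from jax.experimental import optimizers

def loss(params, batch):
  x, y = batch
  return np.sum((np.dot(x, params) - y) ** 2)

init_params = np.zeros(3)

# Each optimizer now returns a triple instead of an (init_fun, update_fun) pair.
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)
opt_state = opt_init(init_params)

@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)   # extract the current iterate from the state
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

batch = (np.ones((5, 3)), np.ones(5))
for i in range(10):
  opt_state = step(i, opt_state, batch)
trained_params = get_params(opt_state)
```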
README.md
@@ -596,10 +596,12 @@ predictions = net_apply(net_params, inputs)

### First-order optimization

JAX has a minimal optimization library focused on stochastic first-order
optimizers. Every optimizer is modeled as an `(init_fun, update_fun)` pair. The
`init_fun` is used to initialize the optimizer state, which could include things
like momentum variables, and the `update_fun` accepts a gradient and an
optimizer state to produce a new optimizer state. The parameters being optimized
optimizers. Every optimizer is modeled as an `(init_fun, update_fun,
get_params)` triple of functions. The `init_fun` is used to initialize the
optimizer state, which could include things like momentum variables, and the
`update_fun` accepts a gradient and an optimizer state to produce a new
optimizer state. The `get_params` function extracts the current iterate (i.e.
the current parameters) from the optimizer state. The parameters being optimized
can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can
store your parameters however you’d like.

@@ -616,12 +618,12 @@ def loss(params, batch):
  return np.sum((predictions - targets)**2)

# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update = optimizers.momentum(step_size=1e-3, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)

# Define a compiled update step
@jit
def step(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

@@ -633,7 +635,7 @@ data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10)))
opt_state = opt_init(net_params)
for i in range(10):
  opt_state = step(i, opt_state, next(data_generator))
net_params = optimizers.get_params(opt_state)
net_params = get_params(opt_state)
```

## How it works
@@ -120,12 +120,12 @@ if __name__ == "__main__":
init_mean = np.zeros(D)
init_std = np.zeros(D)
init_params = (init_mean, init_std)
opt_init, opt_update = optimizers.momentum(step_size=0.1, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(init_params)

@jit
def update(i, opt_state):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  gradient = grad(objective)(params, i)
  return opt_update(i, gradient, opt_state)

@@ -134,6 +134,6 @@ if __name__ == "__main__":
print("Optimizing variational parameters...")
for t in range(100):
  opt_state = update(t, opt_state)
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  callback(params, t)
plt.show(block=True)

@@ -204,16 +204,16 @@ def main(_):
batches = data_stream()

opt_init, opt_update = optimizers.sgd(FLAGS.learning_rate)
opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

@jit
def update(_, i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

@jit
def private_update(rng, i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  rng = random.fold_in(rng, i)  # get new key for new random numbers
  return opt_update(
      i,

@@ -243,7 +243,7 @@ def main(_):
print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

# evaluate test accuracy
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
test_acc = accuracy(params, shape_as_image(test_images, test_labels))
test_loss = loss(params, shape_as_image(test_images, test_labels))
print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(

@@ -42,17 +42,17 @@ def gram(kernel, xs):

def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  opt_init, opt_update = optimizers.momentum(step_size, mass)
  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    x = optimizers.get_params(opt_state)
    x = get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in xrange(num_steps):
    opt_state = update(i, opt_state)
  return optimizers.get_params(opt_state)
  return get_params(opt_state)


def train(kernel, xs, ys, regularization=0.01):

@@ -72,11 +72,11 @@ if __name__ == "__main__":
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))

@@ -90,7 +90,7 @@ if __name__ == "__main__":
opt_state = update(next(itercount), opt_state, next(batches))
epoch_time = time.time() - start_time

params = optimizers.get_params(opt_state)
params = get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

@@ -102,7 +102,7 @@ if __name__ == "__main__":
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params

opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

def binarize_batch(rng, i, images):
  i = i % num_batches

@@ -115,13 +115,13 @@ if __name__ == "__main__":
elbo_rng, data_rng = random.split(random.fold_in(rng, i))
batch = binarize_batch(data_rng, i, train_images)
loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
g = grad(loss)(optimizers.get_params(opt_state))
g = grad(loss)(get_params(opt_state))
return opt_update(i, g, opt_state)
return lax.fori_loop(0, num_batches, body_fun, opt_state)

@jit
def evaluate(opt_state, images):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
  binarized_test = random.bernoulli(data_rng, images)
  test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]

@@ -117,16 +117,16 @@ if __name__ == "__main__":
onehot_labels = labels == np.arange(num_classes)
yield images, onehot_labels

opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
batches = synth_batches()

@jit
def update(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)
for i in xrange(num_steps):
  opt_state = update(i, opt_state, next(batches))
trained_params = optimizers.get_params(opt_state)
trained_params = get_params(opt_state)

@@ -24,12 +24,14 @@ from __future__ import print_function
from functools import partial
import time

import numpy as onp
import numpy.random as npr

from jax import jit, grad, pmap, replicate, unreplicate
from jax import jit, grad, pmap
from jax.config import config
from jax.scipy.special import logsumexp
from jax.lib import xla_bridge
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as np
from examples import datasets

@@ -74,9 +76,8 @@ if __name__ == "__main__":
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

# For this manual SPMD example, we get the number of devices (e.g. GPUs or
# TPUs) that we're using, and use it to reshape data minibatches.
# TPU cores) that we're using, and use it to reshape data minibatches.
num_devices = xla_bridge.device_count()
def data_stream():
  rng = npr.RandomState(0)

@@ -106,20 +107,23 @@ if __name__ == "__main__":
return [(w - step_size * dw, b - step_size * db)
        for (w, b), (dw, db) in zip(params, grads)]

# We replicate parameters out across devices. (Check the implementation of
# replicate; analogous to device_put, it's a simple wrapper around pmap.)
params = replicate(init_random_params(param_scale, layer_sizes))
# We replicate the parameters so that the constituent arrays have a leading
# dimension of size equal to the number of devices we're pmapping over.
init_params = init_random_params(param_scale, layer_sizes)
replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape)
replicated_params = tree_map(replicate_array, init_params)

for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = spmd_update(params, next(batches))
    replicated_params = spmd_update(replicated_params, next(batches))
  epoch_time = time.time() - start_time

  # We evaluate using the jitted `accuracy` function (not using pmap) by
  # grabbing just one of the replicated parameter values.
  train_acc = accuracy(unreplicate(params), (train_images, train_labels))
  test_acc = accuracy(unreplicate(params), (test_images, test_labels))
  params = tree_map(lambda x: x[0], replicated_params)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
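The replication change in the SPMD example above can be shown in isolation. The sketch below uses only what the diff itself relies on (`onp.broadcast_to` plus `tree_map`); the parameter shapes are made up for the example.

```python
# Sketch of manual parameter replication for pmap, as in the SPMD example above.
# The parameter pytree and its shapes are stand-ins for illustration.
import numpy as onp
from jax.lib import xla_bridge
from jax.tree_util import tree_map

num_devices = xla_bridge.device_count()
init_params = [(onp.zeros((784, 512)), onp.zeros(512)),
               (onp.zeros((512, 10)), onp.zeros(10))]

# Give every leaf a leading axis of size num_devices so pmap can split it.
replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape)
replicated_params = tree_map(replicate_array, init_params)

# After training, pull out a single replica for (non-pmapped) evaluation.
params = tree_map(lambda x: x[0], replicated_params)
```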
jax/api.py

@@ -78,12 +78,12 @@ def jit(fun, static_argnums=()):
      provided they are hashable and have an equality operation defined. Static
      arguments are included as part of a compilation cache key, which is why
      hash and equality operators must be defined.
    static_argnums: A tuple of ints. Specifies which positional arguments to
    static_argnums: A tuple of ints specifying which positional arguments to
      treat as static (compile-time constant). Operations that only depend on
      static arguments will be constant-folded. Calling the jitted function with
      different values for these constants will trigger recompilation. If the
      jitted function is called with fewer positional arguments than indicated
      by `static_argnums` then an error is raised.
      by `static_argnums` then an error is raised. Defaults to ().

  Returns:
    A wrapped version of `fun`, set up for just-in-time compilation.

@@ -101,6 +101,9 @@ def jit(fun, static_argnums=()):
  [-0.54485154  0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813
   -0.8574326  -0.7823267   0.7682731   0.59566754]
  """
  return _jit(fun, static_argnums)

def _jit(fun, static_argnums, device_values=True):
  @wraps(fun)
  def f_jitted(*args, **kwargs):
    if _jit_is_disabled or config.read('jax_disable_jit'):

@@ -115,7 +118,7 @@ def jit(fun, static_argnums=()):
    args_flat, in_tree = tree_flatten((dyn_args, kwargs))
    _check_args(args_flat)
    flat_fun, out_tree = flatten_fun(f, in_tree)
    out = xla.xla_call(flat_fun, *args_flat)
    out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    return tree_unflatten(out_tree(), out)

  jitted_name = "jit({}, static_argnums={})"

@@ -807,12 +810,7 @@ tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)


device_put = jit(lambda x: x)
_device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x
device_get = partial(tree_map, _device_get_array)

_replicate_array = lambda x: onp.broadcast_to(x, (device_count(),) + onp.shape(x))
replicate = partial(tree_map, _replicate_array)
unreplicate = lambda x: tree_map(op.itemgetter(0), x)
device_get = _jit(lambda x: x, (), device_values=False)


def _argnums_partial(f, dyn_argnums, args):
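As a brief illustration of the public behavior the api.py change above preserves: `device_put` moves a pytree of values onto the device, while `device_get` (now implemented via the private `_jit(..., device_values=False)`) brings values back to the host. The array values below are arbitrary and the exact returned types are an assumption of this sketch.

```python
# Sketch of device_put / device_get round-tripping a pytree; values are arbitrary.
import jax.numpy as np
from jax import device_put, device_get

params = {'w': np.ones((2, 3)), 'b': np.zeros(3)}
device_params = device_put(params)       # leaves live on the device
host_params = device_get(device_params)  # leaves come back as host values
print(type(host_params['w']))
```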
jax/core.py

@@ -545,7 +545,7 @@ def apply_todos(todos, x):
  return x

@lu.transformation_with_aux
def process_env_traces(primitive, level, *args):
def process_env_traces(primitive, level, params_tuple, *args):
  ans = yield args, {}
  todo = []
  while isinstance(ans, Tracer) and ans.trace.level > level:

@@ -553,25 +553,26 @@ def process_env_traces(primitive, level, *args):
    sublevel = cur_sublevel()
    trace = type(t)(t.master, sublevel)
    ans = trace.full_raise(ans)
    ans, cur_todo = ans.trace.post_process_call(primitive, ans)
    ans, cur_todo = ans.trace.post_process_call(primitive, ans, dict(params_tuple))
    todo.append(cur_todo)
  yield ans, todo

def call_bind(primitive, f, *args, **kwargs):
def call_bind(primitive, f, *args, **params):
  top_trace = find_top_trace(args)
  level = trace_stack.next_level(True) if top_trace is None else top_trace.level
  f, env_trace_todo = process_env_traces(f, primitive, level)
  params_tuple = tuple(params.items())
  f, env_trace_todo = process_env_traces(f, primitive, level, params_tuple)
  if top_trace is None:
    with new_sublevel():
      ans = primitive.impl(f, *args, **kwargs)
      ans = primitive.impl(f, *args, **params)
  else:
    tracers = map(top_trace.full_raise, args)
    ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
    ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
  return apply_todos(env_trace_todo(), ans)


def call_impl(f, *args, **kwargs):
  return f(*args, **kwargs)
def call_impl(f, *args, **params):
  return f(*args, **params)


call_p = Primitive('call')
@@ -29,35 +29,60 @@ import operator

import jax.numpy as np
from jax.core import pack
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import (tree_map, tree_mimomap, tree_structure,
                           register_pytree_node)
from jax import tree_util
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node

map = safe_map
zip = safe_zip

OptimizerState = collections.namedtuple("OptimizerState",
                                        ["packed_state", "tree", "subtrees"])
register_pytree_node(OptimizerState,
                     lambda xs: ((xs.packed_state,), (xs.tree, xs.subtrees)),
                     lambda data, xs: OptimizerState(xs[0], data[0], data[1]))

def optimizer(opt_maker):
  """Decorator to make an optimizer map over tuple/list/dict containers."""
  @functools.wraps(opt_maker)
  def tree_opt_maker(*args, **kwargs):
    init_fun, update_fun = opt_maker(*args, **kwargs)
    init, update, get_params = opt_maker(*args, **kwargs)

    @functools.wraps(init_fun)
    def tree_init_fun(x0_tree):
      return tree_mimomap(init_fun, x0_tree)
    @functools.wraps(init)
    def tree_init(x0_tree):
      x0_flat, tree = tree_flatten(x0_tree)
      initial_states = [init(x0) for x0 in x0_flat]
      states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
      packed_state = pack(map(pack, states_flat))
      return OptimizerState(packed_state, tree, subtrees)

    @functools.wraps(update_fun)
    def tree_update_fun(i, grad_tree, state_trees):
      return tree_mimomap(partial(update_fun, i), grad_tree, *state_trees)
    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      packed_state, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      assert tree == tree2
      states = map(tree_unflatten, subtrees, packed_state)
      new_states = map(partial(update, i), grad_flat, states)
      new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
      for subtree, subtree2 in zip(subtrees, subtrees2):
        if subtree != subtree2:
          msg = ("optimizer update function produced an output structure that "
                 "did not match its input structure: input {} and output {}.")
          raise TypeError(msg.format(subtree, subtree2))
      new_packed_state = pack(map(pack, new_states_flat))
      return OptimizerState(new_packed_state, tree, subtrees)

    return tree_init_fun, tree_update_fun
    @functools.wraps(get_params)
    def tree_get_params(opt_state):
      packed_state, tree, subtrees = opt_state
      states = map(tree_unflatten, subtrees, packed_state)
      params = map(get_params, states)
      return tree_unflatten(tree, params)

    return tree_init, tree_update, tree_get_params
  return tree_opt_maker

def iterate(state_trees):
  """Extract the current iterate from an optimizer state."""
  return state_trees[0]
get_params = iterate

# optimizers
### optimizers

@optimizer
def sgd(step_size):

@@ -68,14 +93,16 @@ def sgd(step_size):
    that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
    An (init, update) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
    return (x0,)
  def update_fun(i, g, x):
    return (x - step_size(i) * g,)
  return init_fun, update_fun
  def init(x0):
    return x0
  def update(i, g, x):
    return x - step_size(i) * g
  def get_params(x):
    return x
  return init, update, get_params

@optimizer
def momentum(step_size, mass):

@@ -86,17 +113,21 @@ def momentum(step_size, mass):
    that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
    An (init, update) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
  def init(x0):
    v0 = np.zeros_like(x0)
    return x0, v0
  def update_fun(i, g, x, velocity):
  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity - (1. - mass) * g
    x = x + step_size(i) * velocity
    return x, velocity
  return init_fun, update_fun
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params

@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):

@@ -107,17 +138,21 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
    that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
    An (init, update) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
  def init(x0):
    avg_sq_grad = np.ones_like(x0)
    return x0, avg_sq_grad
  def update_fun(i, g, x, avg_sq_grad):
  def update(i, g, state):
    x, avg_sq_grad = state
    avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
    x = x - step_size(i) * g / (np.sqrt(avg_sq_grad) + eps)
    return x, avg_sq_grad
  return init_fun, update_fun
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params

@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):

@@ -134,23 +169,28 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun) pair.
    An (init, update) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
  def init(x0):
    m0 = np.zeros_like(x0)
    v0 = np.zeros_like(x0)
    return x0, m0, v0
  def update_fun(i, g, x, m, v):
  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * (g ** 2) + b2 * v  # Second moment estimate.
    mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
    vhat = v / (1 - b2 ** (i + 1))
    x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
    return x, m, v
  return init_fun, update_fun
  def get_params(state):
    x, m, v = state
    return x
  return init, update, get_params

# learning rate schedules

### learning rate schedules

def constant(step_size):
  def schedule(i):

@@ -183,10 +223,10 @@ def piecewise_constant(boundaries, values):
    return values[np.sum(i > boundaries)]
  return schedule

def make_schedule(scalar_or_schedule_fun):
  if callable(scalar_or_schedule_fun):
    return scalar_or_schedule_fun
  elif np.ndim(scalar_or_schedule_fun) == 0:
    return constant(scalar_or_schedule_fun)
def make_schedule(scalar_or_schedule):
  if callable(scalar_or_schedule):
    return scalar_or_schedule
  elif np.ndim(scalar_or_schedule) == 0:
    return constant(scalar_or_schedule)
  else:
    raise TypeError(type(scalar_or_schedule_fun))
    raise TypeError(type(scalar_or_schedule))
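The `@optimizer` decorator above lifts a per-leaf `(init, update, get_params)` triple to work on arbitrary parameter pytrees. A minimal sketch of writing a custom optimizer this way follows; the decayed-SGD rule is made up for illustration and is not part of the commit.

```python
# Sketch: a custom per-leaf optimizer lifted to pytrees by optimizers.optimizer.
# The simple decayed-SGD rule here is illustrative only.
import jax.numpy as np
from jax.experimental import optimizers

@optimizers.optimizer
def decayed_sgd(step_size, decay=0.999):
  def init(x0):
    return x0                      # per-leaf state is just the parameter
  def update(i, g, x):
    return x - step_size * (decay ** i) * g
  def get_params(x):
    return x
  return init, update, get_params

opt_init, opt_update, get_params = decayed_sgd(0.1)
opt_state = opt_init({'w': np.ones(3), 'b': np.zeros(2)})  # works on a dict
grads = {'w': np.ones(3), 'b': np.zeros(2)}                # placeholder gradients
opt_state = opt_update(0, grads, opt_state)
params = get_params(opt_state)
```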
@@ -224,7 +224,7 @@ class JVPTrace(Trace):
    primal_out, tangent_out = build_tree(out_tree_def(), result)
    return JVPTracer(self, primal_out, tangent_out)

  def post_process_call(self, _, out_tracer):
  def post_process_call(self, call_primitive, out_tracer, params):
    out_jtuple, tree_def = tree_to_jaxtuples((out_tracer.primal, out_tracer.tangent))
    master = self.master
    def todo(x):

@@ -131,7 +131,7 @@ class BatchTrace(Trace):
    val_out = call_primitive.bind(f, *vals, **params)
    return BatchTracer(self, val_out, dim_out())

  def post_process_call(self, _, out_tracer):
  def post_process_call(self, call_primitive, out_tracer, params):
    val, dim = out_tracer.val, out_tracer.batch_dim
    master = self.master
    def todo(x):

@@ -145,7 +145,7 @@ class SerialPmapTrace(Trace):
    val_out = call_primitive.bind(f, *vals, **params)
    return SerialPmapTracer(self, name, val_out, axis_out())

  def post_process_call(self, _, out_tracer):
  def post_process_call(self, call_primitive, out_tracer, params):
    name, val, axis = out_tracer.name, out_tracer.val, out_tracer.axis
    master = self.master
    def todo(x):

@@ -111,7 +111,7 @@ class JaxprTrace(Trace):
    eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,), False, params)
    return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn)

  def post_process_call(self, call_primitive, out_tracer):
  def post_process_call(self, call_primitive, out_tracer, params):
    # TODO(mattjj): post_process_map
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracer)
    out_pv, out_pv_const = out_tracer.pval

@@ -123,7 +123,7 @@ class JaxprTrace(Trace):
      const_tracers = map(trace.new_instantiated_const, consts)
      env_tracers = map(trace.full_raise, env)
      bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
      eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, {})
      eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, params)
      return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), eqn)

    return out, todo
@@ -200,6 +200,11 @@ def device_put_many(xs_and_devices):
# manipulate than simple Python builtins, we store the metadata required for
# forming the DeviceValue result in special ResultArray / ResultTuple classes.

# Every JaxType needs to map to an XLA type. However this function's design is
# based on the assumption that XLA types can be mapped uniquely back to a
# JaxType, i.e. that the mapping is bijective. That assumption could be relaxed,
# but it would mean we need to do a bit more bookkeping on the Python side to
# track abstract values of outputs.
def xla_shape_to_result_shape(xla_shape):
  if xla_shape.is_tuple():
    aval = aval_from_xla_shape(xla_shape)

@@ -232,7 +237,8 @@ def pyval_result_handler(result_shape):
  if t is ResultArray:
    return lambda buf: buf.to_py()
  elif t is ResultTuple:
    handlers = list(map(pyval_result_handler, result_shape))
    _, result_shapes = result_shape
    handlers = list(map(pyval_result_handler, result_shapes))
    return lambda buf: JaxTuple(h(b) for h, b in zip(handlers, buf.destructure()))
  else:
    raise TypeError(t)

@@ -548,8 +554,9 @@ def xla_shape(x):
    raise TypeError(type(x))


def xla_call_impl(fun, *args):
  compiled_fun = xla_callable(fun, *map(abstractify, args))
def xla_call_impl(fun, *args, **params):
  device_values = FLAGS.jax_device_values and params.pop('device_values')
  compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
  try:
    return compiled_fun(*args)
  except FloatingPointError:

@@ -559,14 +566,17 @@ def xla_call_impl(fun, *args):


@lu.memoize
def xla_callable(fun, *abstract_args):
def xla_callable(fun, device_values, *abstract_args):
  pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
  with core.new_master(pe.JaxprTrace, True) as master:
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    assert not env  # no subtraces here (though cond might eventually need them)
    compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
    del master, consts, jaxpr, env
  handle_result = result_handler(result_shape)
  if device_values:
    handle_result = device_persistent_result_handler(result_shape)
  else:
    handle_result = pyval_result_handler(result_shape)
  return partial(execute_compiled, compiled, pval, handle_result)

def execute_compiled(compiled, pval, handle_result, *args):

@@ -576,7 +586,7 @@ def execute_compiled(compiled, pval, handle_result, *args):
  return pe.merge_pvals(handle_result(out_buf), pval)


def xla_call_translation_rule(c, subc_a1, *a2):
def xla_call_translation_rule(c, subc_a1, *a2, **params):
  subc, a1 = subc_a1
  return c.Call(subc, a1 + a2)
@@ -152,6 +152,7 @@ def check_vjp(f, f_vjp, args, atol=ATOL, rtol=RTOL, eps=EPS):


def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  args = tuple(args)
  if order > 1:
    def f_vjp(*args):
      out_primal_py, vjp_py = api.vjp(f, *args)

@@ -79,71 +79,24 @@ def tree_multimap(f, tree, *rest):
    leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
    in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
  """
  # equivalent to prefix_multimap(f, tree_structure(tree), tree, *rest)
  node_type = node_types.get(type(tree))
  if node_type:
    children, node_spec = node_type.to_iterable(tree)
    children, aux_data = node_type.to_iterable(tree)
    all_children = [children]
    for other_tree in rest:
      other_node_type = node_types.get(type(other_tree))
      # TODO(mattjj): enable this check
      # if node_type != other_node_type:
      #   raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
      other_children, other_node_data = node_type.to_iterable(other_tree)
      if other_node_data != node_spec:
        raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_spec))
      if node_type != other_node_type:
        raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
      other_children, other_aux_data = node_type.to_iterable(other_tree)
      if other_aux_data != aux_data:
        raise TypeError('Mismatch: {} != {}'.format(other_aux_data, aux_data))
      all_children.append(other_children)

    new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
    return node_type.from_iterable(node_spec, new_children)
    return node_type.from_iterable(aux_data, new_children)
  else:
    return f(tree, *rest)

def prefix_multimap(f, treedef, tree, *rest):
  """Like tree_multimap but only maps down through a tree prefix."""
  if treedef is leaf:
    return f(tree, *rest)
  else:
    node_type = node_types.get(type(tree))
    if node_type != treedef.node_type:
      raise TypeError('Mismatch: {} != {}'.format(treedef.node_type, node_type))
    children, node_data = node_type.to_iterable(tree)
    if node_data != treedef.node_data:
      raise TypeError('Mismatch: {} != {}'.format(treedef.node_data, node_data))
    all_children = [children]
    for other_tree in rest:
      other_children, other_node_data = node_type.to_iterable(other_tree)
      if other_node_data != node_data:
        raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_data))
      all_children.append(other_children)
    all_children = zip(*all_children)

    new_children = [prefix_multimap(f, td, *xs)
                    for td, xs in zip(treedef.children, all_children)]
    return node_type.from_iterable(node_data, new_children)

def tree_mimomap(f, tree, *rest):
  """Map a multi-input tuple-output over pytree args to form a tuple of pytrees.

  Args:
    f: function that takes `1 + len(rest)` arguments and returns a tuple, to be
      applied at the corresponding leaves of the pytrees.
    tree: a pytree to be mapped over, with each leaf providing the first
      positional argument to `f`.
    *rest: a tuple of pytrees, each with the same structure as `tree`.

  Returns:
    A tuple of pytrees with length given by the length of the output of `f` and
    with each pytree element having the same structure as `tree`.
  """
  flat, treedef = tree_flatten(tree)
  rest_flat, treedefs = unzip2(map(tree_flatten, rest))
  if not all(td == treedef for td in treedefs):
    td = next(td for td in treedefs if td != treedef)
    raise TypeError('Mismatch: {} != {}'.format(treedef, td))
  out_flat = zip(*map(f, flat, *rest_flat))
  return tuple(map(partial(tree_unflatten, treedef), out_flat))


def tree_reduce(f, tree):
  flat, _ = tree_flatten(tree)

@@ -270,6 +223,9 @@ class NodeType(object):
    self.to_iterable = to_iterable
    self.from_iterable = from_iterable

  def __repr__(self):
    return self.name

node_types = {}

def register_pytree_node(py_type, to_iterable, from_iterable):
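For context on the tree_util changes above, here is a small usage sketch of `tree_multimap`, which now requires matching node types and aux data across its arguments. The example trees are made up.

```python
# Sketch: tree_multimap maps a function over matching leaves of several pytrees.
# Mismatched node types or aux data (e.g. different dict keys) now raise TypeError.
import jax.numpy as np
from jax.tree_util import tree_multimap

tree_a = {'w': np.ones(3), 'b': np.zeros(2)}
tree_b = {'w': np.ones(3) * 2., 'b': np.ones(2)}

summed = tree_multimap(lambda x, y: x + y, tree_a, tree_b)

# A structure mismatch is reported as a TypeError with a 'Mismatch: ...' message.
try:
  tree_multimap(lambda x, y: x + y, tree_a, {'w': np.ones(3)})
except TypeError as e:
  print(e)
```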
@@ -188,19 +188,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
     "opt_init, opt_update = optimizers.adam(step_size=1e-2)\n",
     "opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)\n",
     "opt_state = opt_init(net_params)\n",
     "\n",
     "# Define a compiled update step\n",
     "@jit\n",
     "def step(i, opt_state, x1, y1):\n",
     "    p = optimizers.get_params(opt_state)\n",
     "    p = get_params(opt_state)\n",
     "    g = grad(loss)(p, x1, y1)\n",
     "    return opt_update(i, g, opt_state)\n",
     "\n",
     "for i in range(100):\n",
     "  opt_state = step(i, opt_state, xrange_inputs, targets)\n",
     "net_params = optimizers.get_params(opt_state)"
     "net_params = get_params(opt_state)"
    ]
   },
   {
@@ -23,12 +23,16 @@ from absl.testing import absltest
import jax.numpy as np
import jax.test_util as jtu
from jax import jit, grad
from jax import core, tree_util
from jax.experimental import optimizers
from jax.lib import xla_bridge as xla
from jax.interpreters import xla

from jax.config import config
config.parse_flags_with_absl()


dummy_data = None

class OptimizerTests(jtu.JaxTestCase):

  def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs):

@@ -36,29 +40,33 @@ class OptimizerTests(jtu.JaxTestCase):
    self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs)

  def _CheckFuns(self, optimizer, loss, x0, *args):
    init_fun, update_fun = optimizer(*args)
    init_fun, update_fun, get_params_fun = optimizer(*args)
    opt_state = init_fun(x0)
    update_fun(0, grad(loss)(x0, None), opt_state)  # doesn't crash
    self.assertAllClose(x0, get_params_fun(opt_state), check_dtypes=True)
    opt_state2 = update_fun(0, grad(loss)(x0, dummy_data), opt_state)  # doesn't crash
    self.assertEqual(tree_util.tree_structure(opt_state),
                     tree_util.tree_structure(opt_state2))

  @jtu.skip_on_devices('gpu')
  def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):
    return  # TODO(mattjj): bring back fax!
    # num_repl = xla.get_replica_count()
    # infeeder = fax.make_infeed_from_sequence(
    #     [np.ones(1, dtype='float32')] * num_steps * num_repl,
    #     with_pyvals=True)
    init_fun, update_fun, get_params_fun = optimizer(*args)

    # def op(infeed, x0):
    #   opt_init, opt_update = optimizer(*args, **kwargs)
    #   return optimizers.run_optimizer(loss, infeed, opt_update, opt_init(x0))
    # cop = jit(op)
    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params_fun(opt_state)
      g = grad(loss)(x, dummy_data)
      opt_state = update_fun(i, g, opt_state)
    xstar = get_params_fun(opt_state)
    self.assertLess(loss(xstar, dummy_data), 1e-2)

    # a1, _ = op(infeeder(), x0)
    # a2, _ = cop(infeeder(), x0)

    # assert loss(a1, None) < 1e-3
    # assert loss(a2, None) < 1e-3
    # self.assertAllClose(a1, a2, check_dtypes=False)
    update_fun_jitted = jit(update_fun)
    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params_fun(opt_state)
      g = grad(loss)(x, dummy_data)
      opt_state = update_fun_jitted(i, g, opt_state)
    xstar = get_params_fun(opt_state)
    self.assertLess(loss(xstar, dummy_data), 1e-2)

  def testSgdScalar(self):
    def loss(x, _): return x**2

@@ -91,6 +99,14 @@ class OptimizerTests(jtu.JaxTestCase):
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testMomentumDict(self):
    def loss(dct, _): return np.dot(dct['x'], dct['x'])
    x0 = {'x': np.ones(2)}
    num_iters = 100
    step_size = 0.1
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testRmspropVector(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)

@@ -118,57 +134,76 @@ class OptimizerTests(jtu.JaxTestCase):
  def testSgdVectorExponentialDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testSgdVectorInverseTimeDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testAdamVectorInverseTimeDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_schedule)
    self._CheckFuns(optimizers.adam, loss, x0, step_schedule)

  def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
    mass = 0.9
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_sched, mass)
    self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)

  def testRmspropVectorPiecewiseConstantSchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
    self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_schedule)
    self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)

  def testTracedStepSize(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_size = 0.1

    init_fun, _ = optimizers.sgd(step_size)
    init_fun, _, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
      _, update_fun = optimizers.sgd(step_size)
      x = optimizers.get_params(opt_state)
      _, update_fun, get_params = optimizers.sgd(step_size)
      x = get_params(opt_state)
      g = grad(loss)(x, None)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash

  def testDeviceTupleState(self):
    init_fun, update_fun, _ = optimizers.sgd(0.1)
    opt_state = init_fun(np.zeros(3))
    self.assertIsInstance(opt_state, optimizers.OptimizerState)
    self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
    opt_state = jit(update_fun)(0, np.zeros(3), opt_state)
    self.assertIsInstance(opt_state, optimizers.OptimizerState)
    self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)

  def testUpdateFunStructureMismatchErrorMessage(self):
    @optimizers.optimizer
    def opt_maker():
      def init_fun(x0):
        return {'x': x0}
      def update_fun(i, g, opt_state):
        x = opt_state['x']
        return {'x': x - 0.1 * g, 'v': g}  # bug!
      def get_params(opt_state):
        return opt_state['x']
      return init_fun, update_fun, get_params

    init_fun, update_fun, get_params = opt_maker()
    opt_state = init_fun(np.zeros(3))
    self.assertRaises(TypeError, lambda: update_fun(opt_state))


if __name__ == '__main__':
  absltest.main()