diff --git a/README.md b/README.md index f4cd7136e..529445cb0 100644 --- a/README.md +++ b/README.md @@ -596,10 +596,12 @@ predictions = net_apply(net_params, inputs) ### First-order optimization JAX has a minimal optimization library focused on stochastic first-order -optimizers. Every optimizer is modeled as an `(init_fun, update_fun)` pair. The -`init_fun` is used to initialize the optimizer state, which could include things -like momentum variables, and the `update_fun` accepts a gradient and an -optimizer state to produce a new optimizer state. The parameters being optimized +optimizers. Every optimizer is modeled as an `(init_fun, update_fun, +get_params)` triple of functions. The `init_fun` is used to initialize the +optimizer state, which could include things like momentum variables, and the +`update_fun` accepts a gradient and an optimizer state to produce a new +optimizer state. The `get_params` function extracts the current iterate (i.e. +the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you’d like. @@ -616,12 +618,12 @@ def loss(params, batch): return np.sum((predictions - targets)**2) # Use optimizers to set optimizer initialization and update functions -opt_init, opt_update = optimizers.momentum(step_size=1e-3, mass=0.9) +opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9) # Define a compiled update step @jit def step(i, opt_state, batch): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) g = grad(loss)(params, batch) return opt_update(i, g, opt_state) @@ -633,7 +635,7 @@ data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10))) opt_state = opt_init(net_params) for i in range(10): opt_state = step(i, opt_state, next(data_generator)) -net_params = optimizers.get_params(opt_state) +net_params = get_params(opt_state) ``` ## How it works diff --git a/examples/advi.py b/examples/advi.py index 756a6f074..36a468b55 100644 --- a/examples/advi.py +++ b/examples/advi.py @@ -120,12 +120,12 @@ if __name__ == "__main__": init_mean = np.zeros(D) init_std = np.zeros(D) init_params = (init_mean, init_std) - opt_init, opt_update = optimizers.momentum(step_size=0.1, mass=0.9) + opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9) opt_state = opt_init(init_params) @jit def update(i, opt_state): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) gradient = grad(objective)(params, i) return opt_update(i, gradient, opt_state) @@ -134,6 +134,6 @@ if __name__ == "__main__": print("Optimizing variational parameters...") for t in range(100): opt_state = update(t, opt_state) - params = optimizers.get_params(opt_state) + params = get_params(opt_state) callback(params, t) plt.show(block=True) diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 44a1293fd..02c8d1022 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -204,16 +204,16 @@ def main(_): batches = data_stream() - opt_init, opt_update = optimizers.sgd(FLAGS.learning_rate) + opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate) @jit def update(_, i, opt_state, batch): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) @jit def private_update(rng, i, opt_state, batch): - params = 
optimizers.get_params(opt_state) + params = get_params(opt_state) rng = random.fold_in(rng, i) # get new key for new random numbers return opt_update( i, @@ -243,7 +243,7 @@ def main(_): print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time)) # evaluate test accuracy - params = optimizers.get_params(opt_state) + params = get_params(opt_state) test_acc = accuracy(params, shape_as_image(test_images, test_labels)) test_loss = loss(params, shape_as_image(test_images, test_labels)) print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format( diff --git a/examples/kernel_lsq.py b/examples/kernel_lsq.py index 7d319d84c..b2fb3fba0 100644 --- a/examples/kernel_lsq.py +++ b/examples/kernel_lsq.py @@ -42,17 +42,17 @@ def gram(kernel, xs): def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9): - opt_init, opt_update = optimizers.momentum(step_size, mass) + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass) @jit def update(i, opt_state): - x = optimizers.get_params(opt_state) + x = get_params(opt_state) return opt_update(i, grad(f)(x), opt_state) opt_state = opt_init(x) for i in xrange(num_steps): opt_state = update(i, opt_state) - return optimizers.get_params(opt_state) + return get_params(opt_state) def train(kernel, xs, ys, regularization=0.01): diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index b9139fbc9..16aa8b45a 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -72,11 +72,11 @@ if __name__ == "__main__": yield train_images[batch_idx], train_labels[batch_idx] batches = data_stream() - opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass) + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass) @jit def update(i, opt_state, batch): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) _, init_params = init_random_params(rng, (-1, 28 * 28)) @@ -90,7 +90,7 @@ if __name__ == "__main__": opt_state = update(next(itercount), opt_state, next(batches)) epoch_time = time.time() - start_time - params = optimizers.get_params(opt_state) + params = get_params(opt_state) train_acc = accuracy(params, (train_images, train_labels)) test_acc = accuracy(params, (test_images, test_labels)) print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index ca7b8e511..44f181105 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -102,7 +102,7 @@ if __name__ == "__main__": _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10)) init_params = init_encoder_params, init_decoder_params - opt_init, opt_update = optimizers.momentum(step_size, mass=0.9) + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) def binarize_batch(rng, i, images): i = i % num_batches @@ -115,13 +115,13 @@ if __name__ == "__main__": elbo_rng, data_rng = random.split(random.fold_in(rng, i)) batch = binarize_batch(data_rng, i, train_images) loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size - g = grad(loss)(optimizers.get_params(opt_state)) + g = grad(loss)(get_params(opt_state)) return opt_update(i, g, opt_state) return lax.fori_loop(0, num_batches, body_fun, opt_state) @jit def evaluate(opt_state, images): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) elbo_rng, data_rng, image_rng = random.split(test_rng, 3) binarized_test = random.bernoulli(data_rng, images) test_elbo = 
elbo(elbo_rng, params, binarized_test) / images.shape[0] diff --git a/examples/resnet50.py b/examples/resnet50.py index e9783ff08..770ecc7fd 100644 --- a/examples/resnet50.py +++ b/examples/resnet50.py @@ -117,16 +117,16 @@ if __name__ == "__main__": onehot_labels = labels == np.arange(num_classes) yield images, onehot_labels - opt_init, opt_update = optimizers.momentum(step_size, mass=0.9) + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) batches = synth_batches() @jit def update(i, opt_state, batch): - params = optimizers.get_params(opt_state) + params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) opt_state = opt_init(init_params) for i in xrange(num_steps): opt_state = update(i, opt_state, next(batches)) - trained_params = optimizers.get_params(opt_state) + trained_params = get_params(opt_state) diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 0559f0567..9afee1ae9 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -24,12 +24,14 @@ from __future__ import print_function from functools import partial import time +import numpy as onp import numpy.random as npr -from jax import jit, grad, pmap, replicate, unreplicate +from jax import jit, grad, pmap from jax.config import config from jax.scipy.special import logsumexp from jax.lib import xla_bridge +from jax.tree_util import tree_map from jax import lax import jax.numpy as np from examples import datasets @@ -74,9 +76,8 @@ if __name__ == "__main__": num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) - # For this manual SPMD example, we get the number of devices (e.g. GPUs or - # TPUs) that we're using, and use it to reshape data minibatches. + # TPU cores) that we're using, and use it to reshape data minibatches. num_devices = xla_bridge.device_count() def data_stream(): rng = npr.RandomState(0) @@ -106,20 +107,23 @@ if __name__ == "__main__": return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] - # We replicate parameters out across devices. (Check the implementation of - # replicate; analogous to device_put, it's a simple wrapper around pmap.) - params = replicate(init_random_params(param_scale, layer_sizes)) + # We replicate the parameters so that the constituent arrays have a leading + # dimension of size equal to the number of devices we're pmapping over. + init_params = init_random_params(param_scale, layer_sizes) + replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape) + replicated_params = tree_map(replicate_array, init_params) for epoch in range(num_epochs): start_time = time.time() for _ in range(num_batches): - params = spmd_update(params, next(batches)) + replicated_params = spmd_update(replicated_params, next(batches)) epoch_time = time.time() - start_time # We evaluate using the jitted `accuracy` function (not using pmap) by # grabbing just one of the replicated parameter values. 
- train_acc = accuracy(unreplicate(params), (train_images, train_labels)) - test_acc = accuracy(unreplicate(params), (test_images, test_labels)) + params = tree_map(lambda x: x[0], replicated_params) + train_acc = accuracy(params, (train_images, train_labels)) + test_acc = accuracy(params, (test_images, test_labels)) print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) print("Training set accuracy {}".format(train_acc)) print("Test set accuracy {}".format(test_acc)) diff --git a/jax/api.py b/jax/api.py index 70e894969..4fb194f20 100644 --- a/jax/api.py +++ b/jax/api.py @@ -78,12 +78,12 @@ def jit(fun, static_argnums=()): provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined. - static_argnums: A tuple of ints. Specifies which positional arguments to + static_argnums: A tuple of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated - by `static_argnums` then an error is raised. + by `static_argnums` then an error is raised. Defaults to (). Returns: A wrapped version of `fun`, set up for just-in-time compilation. @@ -101,6 +101,9 @@ def jit(fun, static_argnums=()): [-0.54485154 0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813 -0.8574326 -0.7823267 0.7682731 0.59566754] """ + return _jit(fun, static_argnums) + +def _jit(fun, static_argnums, device_values=True): @wraps(fun) def f_jitted(*args, **kwargs): if _jit_is_disabled or config.read('jax_disable_jit'): @@ -115,7 +118,7 @@ def jit(fun, static_argnums=()): args_flat, in_tree = tree_flatten((dyn_args, kwargs)) _check_args(args_flat) flat_fun, out_tree = flatten_fun(f, in_tree) - out = xla.xla_call(flat_fun, *args_flat) + out = xla.xla_call(flat_fun, *args_flat, device_values=device_values) return tree_unflatten(out_tree(), out) jitted_name = "jit({}, static_argnums={})" @@ -807,12 +810,7 @@ tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals) device_put = jit(lambda x: x) -_device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x -device_get = partial(tree_map, _device_get_array) - -_replicate_array = lambda x: onp.broadcast_to(x, (device_count(),) + onp.shape(x)) -replicate = partial(tree_map, _replicate_array) -unreplicate = lambda x: tree_map(op.itemgetter(0), x) +device_get = _jit(lambda x: x, (), device_values=False) def _argnums_partial(f, dyn_argnums, args): diff --git a/jax/core.py b/jax/core.py index c8b502dd2..0f18a1bf4 100644 --- a/jax/core.py +++ b/jax/core.py @@ -545,7 +545,7 @@ def apply_todos(todos, x): return x @lu.transformation_with_aux -def process_env_traces(primitive, level, *args): +def process_env_traces(primitive, level, params_tuple, *args): ans = yield args, {} todo = [] while isinstance(ans, Tracer) and ans.trace.level > level: @@ -553,25 +553,26 @@ def process_env_traces(primitive, level, *args): sublevel = cur_sublevel() trace = type(t)(t.master, sublevel) ans = trace.full_raise(ans) - ans, cur_todo = ans.trace.post_process_call(primitive, ans) + ans, cur_todo = ans.trace.post_process_call(primitive, ans, dict(params_tuple)) todo.append(cur_todo) yield ans, todo -def call_bind(primitive, f, *args, **kwargs): +def call_bind(primitive, f, *args, **params): 
top_trace = find_top_trace(args) level = trace_stack.next_level(True) if top_trace is None else top_trace.level - f, env_trace_todo = process_env_traces(f, primitive, level) + params_tuple = tuple(params.items()) + f, env_trace_todo = process_env_traces(f, primitive, level, params_tuple) if top_trace is None: with new_sublevel(): - ans = primitive.impl(f, *args, **kwargs) + ans = primitive.impl(f, *args, **params) else: tracers = map(top_trace.full_raise, args) - ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs)) + ans = full_lower(top_trace.process_call(primitive, f, tracers, params)) return apply_todos(env_trace_todo(), ans) -def call_impl(f, *args, **kwargs): - return f(*args, **kwargs) +def call_impl(f, *args, **params): + return f(*args, **params) call_p = Primitive('call') diff --git a/jax/experimental/optimizers.py b/jax/experimental/optimizers.py index 82c0ad2d6..b5925d649 100644 --- a/jax/experimental/optimizers.py +++ b/jax/experimental/optimizers.py @@ -29,35 +29,60 @@ import operator import jax.numpy as np from jax.core import pack from jax.util import partial, safe_zip, safe_map, unzip2 -from jax.tree_util import (tree_map, tree_mimomap, tree_structure, - register_pytree_node) +from jax import tree_util +from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node map = safe_map zip = safe_zip +OptimizerState = collections.namedtuple("OptimizerState", + ["packed_state", "tree", "subtrees"]) +register_pytree_node(OptimizerState, + lambda xs: ((xs.packed_state,), (xs.tree, xs.subtrees)), + lambda data, xs: OptimizerState(xs[0], data[0], data[1])) + def optimizer(opt_maker): """Decorator to make an optimizer map over tuple/list/dict containers.""" @functools.wraps(opt_maker) def tree_opt_maker(*args, **kwargs): - init_fun, update_fun = opt_maker(*args, **kwargs) + init, update, get_params = opt_maker(*args, **kwargs) - @functools.wraps(init_fun) - def tree_init_fun(x0_tree): - return tree_mimomap(init_fun, x0_tree) + @functools.wraps(init) + def tree_init(x0_tree): + x0_flat, tree = tree_flatten(x0_tree) + initial_states = [init(x0) for x0 in x0_flat] + states_flat, subtrees = unzip2(map(tree_flatten, initial_states)) + packed_state = pack(map(pack, states_flat)) + return OptimizerState(packed_state, tree, subtrees) - @functools.wraps(update_fun) - def tree_update_fun(i, grad_tree, state_trees): - return tree_mimomap(partial(update_fun, i), grad_tree, *state_trees) + @functools.wraps(update) + def tree_update(i, grad_tree, opt_state): + packed_state, tree, subtrees = opt_state + grad_flat, tree2 = tree_flatten(grad_tree) + assert tree == tree2 + states = map(tree_unflatten, subtrees, packed_state) + new_states = map(partial(update, i), grad_flat, states) + new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states)) + for subtree, subtree2 in zip(subtrees, subtrees2): + if subtree != subtree2: + msg = ("optimizer update function produced an output structure that " + "did not match its input structure: input {} and output {}.") + raise TypeError(msg.format(subtree, subtree2)) + new_packed_state = pack(map(pack, new_states_flat)) + return OptimizerState(new_packed_state, tree, subtrees) - return tree_init_fun, tree_update_fun + @functools.wraps(get_params) + def tree_get_params(opt_state): + packed_state, tree, subtrees = opt_state + states = map(tree_unflatten, subtrees, packed_state) + params = map(get_params, states) + return tree_unflatten(tree, params) + + return tree_init, tree_update, tree_get_params return tree_opt_maker -def 
iterate(state_trees):
-  """Extract the current iterate from an optimizer state."""
-  return state_trees[0]
-get_params = iterate
-# optimizers
+### optimizers
 @optimizer
 def sgd(step_size):
@@ -68,14 +93,16 @@ def sgd(step_size):
       that maps the iteration index to positive scalar.
   Returns:
-    An (init_fun, update_fun) pair.
+    An (init, update, get_params) triple.
   """
   step_size = make_schedule(step_size)
-  def init_fun(x0):
-    return (x0,)
-  def update_fun(i, g, x):
-    return (x - step_size(i) * g,)
-  return init_fun, update_fun
+  def init(x0):
+    return x0
+  def update(i, g, x):
+    return x - step_size(i) * g
+  def get_params(x):
+    return x
+  return init, update, get_params
 @optimizer
 def momentum(step_size, mass):
@@ -86,17 +113,21 @@ def momentum(step_size, mass):
       that maps the iteration index to positive scalar.
   Returns:
-    An (init_fun, update_fun) pair.
+    An (init, update, get_params) triple.
   """
   step_size = make_schedule(step_size)
-  def init_fun(x0):
+  def init(x0):
     v0 = np.zeros_like(x0)
     return x0, v0
-  def update_fun(i, g, x, velocity):
+  def update(i, g, state):
+    x, velocity = state
     velocity = mass * velocity - (1. - mass) * g
     x = x + step_size(i) * velocity
     return x, velocity
-  return init_fun, update_fun
+  def get_params(state):
+    x, _ = state
+    return x
+  return init, update, get_params
 @optimizer
 def rmsprop(step_size, gamma=0.9, eps=1e-8):
@@ -107,17 +138,21 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
       that maps the iteration index to positive scalar.
   Returns:
-    An (init_fun, update_fun) pair.
+    An (init, update, get_params) triple.
   """
   step_size = make_schedule(step_size)
-  def init_fun(x0):
+  def init(x0):
     avg_sq_grad = np.ones_like(x0)
     return x0, avg_sq_grad
-  def update_fun(i, g, x, avg_sq_grad):
+  def update(i, g, state):
+    x, avg_sq_grad = state
     avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
     x = x - step_size(i) * g / (np.sqrt(avg_sq_grad) + eps)
     return x, avg_sq_grad
-  return init_fun, update_fun
+  def get_params(state):
+    x, _ = state
+    return x
+  return init, update, get_params
 @optimizer
 def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
@@ -134,23 +169,28 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
       numerical stability (default 1e-8).
   Returns:
-    An (init_fun, update_fun) pair.
+    An (init, update, get_params) triple.
   """
   step_size = make_schedule(step_size)
-  def init_fun(x0):
+  def init(x0):
     m0 = np.zeros_like(x0)
     v0 = np.zeros_like(x0)
     return x0, m0, v0
-  def update_fun(i, g, x, m, v):
+  def update(i, g, state):
+    x, m, v = state
     m = (1 - b1) * g + b1 * m  # First moment estimate.
     v = (1 - b2) * (g ** 2) + b2 * v  # Second moment estimate.
     mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
vhat = v / (1 - b2 ** (i + 1)) x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps) return x, m, v - return init_fun, update_fun + def get_params(state): + x, m, v = state + return x + return init, update, get_params -# learning rate schedules + +### learning rate schedules def constant(step_size): def schedule(i): @@ -183,10 +223,10 @@ def piecewise_constant(boundaries, values): return values[np.sum(i > boundaries)] return schedule -def make_schedule(scalar_or_schedule_fun): - if callable(scalar_or_schedule_fun): - return scalar_or_schedule_fun - elif np.ndim(scalar_or_schedule_fun) == 0: - return constant(scalar_or_schedule_fun) +def make_schedule(scalar_or_schedule): + if callable(scalar_or_schedule): + return scalar_or_schedule + elif np.ndim(scalar_or_schedule) == 0: + return constant(scalar_or_schedule) else: - raise TypeError(type(scalar_or_schedule_fun)) + raise TypeError(type(scalar_or_schedule)) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 3eefeb681..709d817b5 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -224,7 +224,7 @@ class JVPTrace(Trace): primal_out, tangent_out = build_tree(out_tree_def(), result) return JVPTracer(self, primal_out, tangent_out) - def post_process_call(self, _, out_tracer): + def post_process_call(self, call_primitive, out_tracer, params): out_jtuple, tree_def = tree_to_jaxtuples((out_tracer.primal, out_tracer.tangent)) master = self.master def todo(x): diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 57e1133c4..05aaa2b6a 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -131,7 +131,7 @@ class BatchTrace(Trace): val_out = call_primitive.bind(f, *vals, **params) return BatchTracer(self, val_out, dim_out()) - def post_process_call(self, _, out_tracer): + def post_process_call(self, call_primitive, out_tracer, params): val, dim = out_tracer.val, out_tracer.batch_dim master = self.master def todo(x): diff --git a/jax/interpreters/parallel.py b/jax/interpreters/parallel.py index a35867fda..d14c601b3 100644 --- a/jax/interpreters/parallel.py +++ b/jax/interpreters/parallel.py @@ -145,7 +145,7 @@ class SerialPmapTrace(Trace): val_out = call_primitive.bind(f, *vals, **params) return SerialPmapTracer(self, name, val_out, axis_out()) - def post_process_call(self, _, out_tracer): + def post_process_call(self, call_primitive, out_tracer, params): name, val, axis = out_tracer.name, out_tracer.val, out_tracer.axis master = self.master def todo(x): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index d62a54d69..21741bb23 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -111,7 +111,7 @@ class JaxprTrace(Trace): eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,), False, params) return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn) - def post_process_call(self, call_primitive, out_tracer): + def post_process_call(self, call_primitive, out_tracer, params): # TODO(mattjj): post_process_map jaxpr, consts, env = tracers_to_jaxpr([], out_tracer) out_pv, out_pv_const = out_tracer.pval @@ -123,7 +123,7 @@ class JaxprTrace(Trace): const_tracers = map(trace.new_instantiated_const, consts) env_tracers = map(trace.full_raise, env) bound_subjaxpr = (jaxpr, const_tracers, env_tracers) - eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, {}) + eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, params) return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), 
eqn)
     return out, todo
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 383581d4d..c1548e666 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -200,6 +200,11 @@ def device_put_many(xs_and_devices):
 # manipulate than simple Python builtins, we store the metadata required for
 # forming the DeviceValue result in special ResultArray / ResultTuple classes.
+# Every JaxType needs to map to an XLA type. However this function's design is
+# based on the assumption that XLA types can be mapped uniquely back to a
+# JaxType, i.e. that the mapping is bijective. That assumption could be relaxed,
+# but it would mean we need to do a bit more bookkeeping on the Python side to
+# track abstract values of outputs.
 def xla_shape_to_result_shape(xla_shape):
   if xla_shape.is_tuple():
     aval = aval_from_xla_shape(xla_shape)
@@ -232,7 +237,8 @@ def pyval_result_handler(result_shape):
   if t is ResultArray:
     return lambda buf: buf.to_py()
   elif t is ResultTuple:
-    handlers = list(map(pyval_result_handler, result_shape))
+    _, result_shapes = result_shape
+    handlers = list(map(pyval_result_handler, result_shapes))
     return lambda buf: JaxTuple(h(b) for h, b in zip(handlers, buf.destructure()))
   else:
     raise TypeError(t)
@@ -548,8 +554,9 @@ def xla_shape(x):
     raise TypeError(type(x))
-def xla_call_impl(fun, *args):
-  compiled_fun = xla_callable(fun, *map(abstractify, args))
+def xla_call_impl(fun, *args, **params):
+  device_values = FLAGS.jax_device_values and params.pop('device_values')
+  compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
   try:
     return compiled_fun(*args)
   except FloatingPointError:
@@ -559,14 +566,17 @@
 @lu.memoize
-def xla_callable(fun, *abstract_args):
+def xla_callable(fun, device_values, *abstract_args):
   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
   with core.new_master(pe.JaxprTrace, True) as master:
     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
     assert not env  # no subtraces here (though cond might eventually need them)
     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
     del master, consts, jaxpr, env
-  handle_result = result_handler(result_shape)
+  if device_values:
+    handle_result = device_persistent_result_handler(result_shape)
+  else:
+    handle_result = pyval_result_handler(result_shape)
   return partial(execute_compiled, compiled, pval, handle_result)
 def execute_compiled(compiled, pval, handle_result, *args):
@@ -576,7 +586,7 @@
   return pe.merge_pvals(handle_result(out_buf), pval)
-def xla_call_translation_rule(c, subc_a1, *a2):
+def xla_call_translation_rule(c, subc_a1, *a2, **params):
   subc, a1 = subc_a1
   return c.Call(subc, a1 + a2)
diff --git a/jax/test_util.py b/jax/test_util.py
index 8f910aaf6..580cf5735 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -152,6 +152,7 @@ def check_vjp(f, f_vjp, args, atol=ATOL, rtol=RTOL, eps=EPS):
 def check_grads(f, args, order, atol=None, rtol=None, eps=None):
+  args = tuple(args)
   if order > 1:
     def f_vjp(*args):
       out_primal_py, vjp_py = api.vjp(f, *args)
diff --git a/jax/tree_util.py b/jax/tree_util.py
index bab0a5437..41f9e470f 100644
--- a/jax/tree_util.py
+++ b/jax/tree_util.py
@@ -79,71 +79,24 @@ def tree_multimap(f, tree, *rest):
     leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
     in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
""" - # equivalent to prefix_multimap(f, tree_structure(tree), tree, *rest) node_type = node_types.get(type(tree)) if node_type: - children, node_spec = node_type.to_iterable(tree) + children, aux_data = node_type.to_iterable(tree) all_children = [children] for other_tree in rest: other_node_type = node_types.get(type(other_tree)) - # TODO(mattjj): enable this check - # if node_type != other_node_type: - # raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type)) - other_children, other_node_data = node_type.to_iterable(other_tree) - if other_node_data != node_spec: - raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_spec)) + if node_type != other_node_type: + raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type)) + other_children, other_aux_data = node_type.to_iterable(other_tree) + if other_aux_data != aux_data: + raise TypeError('Mismatch: {} != {}'.format(other_aux_data, aux_data)) all_children.append(other_children) new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)] - return node_type.from_iterable(node_spec, new_children) + return node_type.from_iterable(aux_data, new_children) else: return f(tree, *rest) -def prefix_multimap(f, treedef, tree, *rest): - """Like tree_multimap but only maps down through a tree prefix.""" - if treedef is leaf: - return f(tree, *rest) - else: - node_type = node_types.get(type(tree)) - if node_type != treedef.node_type: - raise TypeError('Mismatch: {} != {}'.format(treedef.node_type, node_type)) - children, node_data = node_type.to_iterable(tree) - if node_data != treedef.node_data: - raise TypeError('Mismatch: {} != {}'.format(treedef.node_data, node_data)) - all_children = [children] - for other_tree in rest: - other_children, other_node_data = node_type.to_iterable(other_tree) - if other_node_data != node_data: - raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_data)) - all_children.append(other_children) - all_children = zip(*all_children) - - new_children = [prefix_multimap(f, td, *xs) - for td, xs in zip(treedef.children, all_children)] - return node_type.from_iterable(node_data, new_children) - -def tree_mimomap(f, tree, *rest): - """Map a multi-input tuple-output over pytree args to form a tuple of pytrees. - - Args: - f: function that takes `1 + len(rest)` arguments and returns a tuple, to be - applied at the corresponding leaves of the pytrees. - tree: a pytree to be mapped over, with each leaf providing the first - positional argument to `f`. - *rest: a tuple of pytrees, each with the same structure as `tree`. - - Returns: - A tuple of pytrees with length given by the length of the output of `f` and - with each pytree element having the same structure as `tree`. 
- """ - flat, treedef = tree_flatten(tree) - rest_flat, treedefs = unzip2(map(tree_flatten, rest)) - if not all(td == treedef for td in treedefs): - td = next(td for td in treedefs if td != treedef) - raise TypeError('Mismatch: {} != {}'.format(treedef, td)) - out_flat = zip(*map(f, flat, *rest_flat)) - return tuple(map(partial(tree_unflatten, treedef), out_flat)) - def tree_reduce(f, tree): flat, _ = tree_flatten(tree) @@ -270,6 +223,9 @@ class NodeType(object): self.to_iterable = to_iterable self.from_iterable = from_iterable + def __repr__(self): + return self.name + node_types = {} def register_pytree_node(py_type, to_iterable, from_iterable): diff --git a/notebooks/maml.ipynb b/notebooks/maml.ipynb index f20db80c3..3a1247d52 100644 --- a/notebooks/maml.ipynb +++ b/notebooks/maml.ipynb @@ -188,19 +188,19 @@ "metadata": {}, "outputs": [], "source": [ - "opt_init, opt_update = optimizers.adam(step_size=1e-2)\n", + "opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)\n", "opt_state = opt_init(net_params)\n", "\n", "# Define a compiled update step\n", "@jit\n", "def step(i, opt_state, x1, y1):\n", - " p = optimizers.get_params(opt_state)\n", + " p = get_params(opt_state)\n", " g = grad(loss)(p, x1, y1)\n", " return opt_update(i, g, opt_state)\n", "\n", "for i in range(100):\n", " opt_state = step(i, opt_state, xrange_inputs, targets)\n", - "net_params = optimizers.get_params(opt_state)" + "net_params = get_params(opt_state)" ] }, { diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 2a7f9ff32..a86da5660 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -23,12 +23,16 @@ from absl.testing import absltest import jax.numpy as np import jax.test_util as jtu from jax import jit, grad +from jax import core, tree_util from jax.experimental import optimizers -from jax.lib import xla_bridge as xla +from jax.interpreters import xla from jax.config import config config.parse_flags_with_absl() + +dummy_data = None + class OptimizerTests(jtu.JaxTestCase): def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs): @@ -36,29 +40,33 @@ class OptimizerTests(jtu.JaxTestCase): self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs) def _CheckFuns(self, optimizer, loss, x0, *args): - init_fun, update_fun = optimizer(*args) + init_fun, update_fun, get_params_fun = optimizer(*args) opt_state = init_fun(x0) - update_fun(0, grad(loss)(x0, None), opt_state) # doesn't crash + self.assertAllClose(x0, get_params_fun(opt_state), check_dtypes=True) + opt_state2 = update_fun(0, grad(loss)(x0, dummy_data), opt_state) # doesn't crash + self.assertEqual(tree_util.tree_structure(opt_state), + tree_util.tree_structure(opt_state2)) @jtu.skip_on_devices('gpu') def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs): - return # TODO(mattjj): bring back fax! 
- # num_repl = xla.get_replica_count() - # infeeder = fax.make_infeed_from_sequence( - # [np.ones(1, dtype='float32')] * num_steps * num_repl, - # with_pyvals=True) + init_fun, update_fun, get_params_fun = optimizer(*args) - # def op(infeed, x0): - # opt_init, opt_update = optimizer(*args, **kwargs) - # return optimizers.run_optimizer(loss, infeed, opt_update, opt_init(x0)) - # cop = jit(op) + opt_state = init_fun(x0) + for i in range(num_steps): + x = get_params_fun(opt_state) + g = grad(loss)(x, dummy_data) + opt_state = update_fun(i, g, opt_state) + xstar = get_params_fun(opt_state) + self.assertLess(loss(xstar, dummy_data), 1e-2) - # a1, _ = op(infeeder(), x0) - # a2, _ = cop(infeeder(), x0) - - # assert loss(a1, None) < 1e-3 - # assert loss(a2, None) < 1e-3 - # self.assertAllClose(a1, a2, check_dtypes=False) + update_fun_jitted = jit(update_fun) + opt_state = init_fun(x0) + for i in range(num_steps): + x = get_params_fun(opt_state) + g = grad(loss)(x, dummy_data) + opt_state = update_fun_jitted(i, g, opt_state) + xstar = get_params_fun(opt_state) + self.assertLess(loss(xstar, dummy_data), 1e-2) def testSgdScalar(self): def loss(x, _): return x**2 @@ -91,6 +99,14 @@ class OptimizerTests(jtu.JaxTestCase): mass = 0. self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass) + def testMomentumDict(self): + def loss(dct, _): return np.dot(dct['x'], dct['x']) + x0 = {'x': np.ones(2)} + num_iters = 100 + step_size = 0.1 + mass = 0. + self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass) + def testRmspropVector(self): def loss(x, _): return np.dot(x, x) x0 = np.ones(2) @@ -118,57 +134,76 @@ class OptimizerTests(jtu.JaxTestCase): def testSgdVectorExponentialDecaySchedule(self): def loss(x, _): return np.dot(x, x) x0 = np.ones(2) - num_iters = 100 step_schedule = optimizers.exponential_decay(0.1, 3, 2.) - self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule) + self._CheckFuns(optimizers.sgd, loss, x0, step_schedule) def testSgdVectorInverseTimeDecaySchedule(self): def loss(x, _): return np.dot(x, x) x0 = np.ones(2) - num_iters = 100 step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.) - self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule) + self._CheckFuns(optimizers.sgd, loss, x0, step_schedule) def testAdamVectorInverseTimeDecaySchedule(self): def loss(x, _): return np.dot(x, x) x0 = np.ones(2) - num_iters = 100 step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.) 
-    self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_schedule)
+    self._CheckFuns(optimizers.adam, loss, x0, step_schedule)
   def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
     def loss(x, _): return np.dot(x, x)
     x0 = np.ones(2)
-    num_iters = 100
     step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
     mass = 0.9
-    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_sched, mass)
+    self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)
   def testRmspropVectorPiecewiseConstantSchedule(self):
     def loss(x, _): return np.dot(x, x)
     x0 = np.ones(2)
-    num_iters = 100
     step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
-    self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_schedule)
+    self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)
   def testTracedStepSize(self):
     def loss(x, _): return np.dot(x, x)
     x0 = np.ones(2)
-    num_iters = 100
     step_size = 0.1
-    init_fun, _ = optimizers.sgd(step_size)
+    init_fun, _, _ = optimizers.sgd(step_size)
     opt_state = init_fun(x0)
     @jit
     def update(opt_state, step_size):
-      _, update_fun = optimizers.sgd(step_size)
-      x = optimizers.get_params(opt_state)
+      _, update_fun, get_params = optimizers.sgd(step_size)
+      x = get_params(opt_state)
       g = grad(loss)(x, None)
       return update_fun(0, g, opt_state)
     update(opt_state, 0.9)  # doesn't crash
+  def testDeviceTupleState(self):
+    init_fun, update_fun, _ = optimizers.sgd(0.1)
+    opt_state = init_fun(np.zeros(3))
+    self.assertIsInstance(opt_state, optimizers.OptimizerState)
+    self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
+    opt_state = jit(update_fun)(0, np.zeros(3), opt_state)
+    self.assertIsInstance(opt_state, optimizers.OptimizerState)
+    self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)
+
+  def testUpdateFunStructureMismatchErrorMessage(self):
+    @optimizers.optimizer
+    def opt_maker():
+      def init_fun(x0):
+        return {'x': x0}
+      def update_fun(i, g, opt_state):
+        x = opt_state['x']
+        return {'x': x - 0.1 * g, 'v': g}  # bug!
+      def get_params(opt_state):
+        return opt_state['x']
+      return init_fun, update_fun, get_params
+
+    init_fun, update_fun, get_params = opt_maker()
+    opt_state = init_fun(np.zeros(3))
+    self.assertRaises(TypeError, lambda: update_fun(0, np.zeros(3), opt_state))
+
 if __name__ == '__main__':
   absltest.main()
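
For context, here is a minimal end-to-end sketch of the `(init, update, get_params)` triple API and the `@optimizers.optimizer` decorator introduced in `jax/experimental/optimizers.py` above. The custom `plain_sgd` optimizer and the toy `loss` are illustrative assumptions, not part of the patch; the decorator, the returned triple, and the pytree-structured parameters are the ones defined in this diff.

```python
# A usage sketch of the new optimizer API (illustrative, not part of the patch).
import jax.numpy as np
from jax import grad, jit
from jax.experimental import optimizers

@optimizers.optimizer
def plain_sgd(step_size):
  # Written for a single ndarray; the decorator lifts it to arbitrary pytrees.
  def init(x0):
    return x0                   # the state is just the current iterate
  def update(i, g, x):
    return x - step_size * g    # one gradient step
  def get_params(x):
    return x                    # the iterate is the whole state
  return init, update, get_params

def loss(params):
  return np.sum(params['w'] ** 2) + np.sum(params['b'] ** 2)

opt_init, opt_update, get_params = plain_sgd(step_size=0.1)
opt_state = opt_init({'w': np.ones(3), 'b': np.ones(2)})  # dict of parameters

@jit
def step(i, opt_state):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params), opt_state)

for i in range(100):
  opt_state = step(i, opt_state)
print(get_params(opt_state))  # {'b': ..., 'w': ...} driven toward zero
```

Because `get_params` is returned by the optimizer maker rather than imported from the module, callers never depend on how the optimizer state is laid out internally (here, an `OptimizerState` namedtuple holding a packed `JaxTuple`).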