revise optimizers api, fix misc bugs

* add more optimizers numerical tests
* update examples and readme with new optimizers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
Matthew Johnson 2019-05-03 12:37:14 -07:00
parent efe99c8a84
commit 642d2dc802
20 changed files with 239 additions and 192 deletions


@ -596,10 +596,12 @@ predictions = net_apply(net_params, inputs)
### First-order optimization
JAX has a minimal optimization library focused on stochastic first-order
optimizers. Every optimizer is modeled as an `(init_fun, update_fun)` pair. The
`init_fun` is used to initialize the optimizer state, which could include things
like momentum variables, and the `update_fun` accepts a gradient and an
optimizer state to produce a new optimizer state. The parameters being optimized
optimizers. Every optimizer is modeled as an `(init_fun, update_fun,
get_params)` triple of functions. The `init_fun` is used to initialize the
optimizer state, which could include things like momentum variables, and the
`update_fun` accepts a gradient and an optimizer state to produce a new
optimizer state. The `get_params` function extracts the current iterate (i.e.
the current parameters) from the optimizer state. The parameters being optimized
can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can
store your parameters however you'd like.
@ -616,12 +618,12 @@ def loss(params, batch):
return np.sum((predictions - targets)**2)
# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update = optimizers.momentum(step_size=1e-3, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)
# Define a compiled update step
@jit
def step(i, opt_state, batch):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
g = grad(loss)(params, batch)
return opt_update(i, g, opt_state)
@ -633,7 +635,7 @@ data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10)))
opt_state = opt_init(net_params)
for i in range(10):
opt_state = step(i, opt_state, next(data_generator))
net_params = optimizers.get_params(opt_state)
net_params = get_params(opt_state)
```
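
Since the optimizer functions flatten parameter pytrees internally (see the optimizers.py changes below), the same `(init_fun, update_fun, get_params)` triple works unchanged when the parameters are a nested container. A minimal sketch under that assumption; the dict layout and shapes here are purely illustrative:

```python
import jax.numpy as np
from jax import jit, grad
from jax.experimental import optimizers

# Illustrative nested parameters; any list/tuple/dict nesting works.
params = {'w': np.zeros((3, 2)), 'b': np.zeros(3)}

def loss(params, x):
  return np.sum((np.dot(params['w'], x) + params['b']) ** 2)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

@jit
def step(i, opt_state, x):
  g = grad(loss)(get_params(opt_state), x)
  return opt_update(i, g, opt_state)

opt_state = step(0, opt_state, np.ones(2))
print(get_params(opt_state)['w'].shape)  # nested structure is preserved
```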
## How it works


@ -120,12 +120,12 @@ if __name__ == "__main__":
init_mean = np.zeros(D)
init_std = np.zeros(D)
init_params = (init_mean, init_std)
opt_init, opt_update = optimizers.momentum(step_size=0.1, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(init_params)
@jit
def update(i, opt_state):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
gradient = grad(objective)(params, i)
return opt_update(i, gradient, opt_state)
@ -134,6 +134,6 @@ if __name__ == "__main__":
print("Optimizing variational parameters...")
for t in range(100):
opt_state = update(t, opt_state)
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
callback(params, t)
plt.show(block=True)


@ -204,16 +204,16 @@ def main(_):
batches = data_stream()
opt_init, opt_update = optimizers.sgd(FLAGS.learning_rate)
opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
@jit
def update(_, i, opt_state, batch):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
@jit
def private_update(rng, i, opt_state, batch):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
rng = random.fold_in(rng, i) # get new key for new random numbers
return opt_update(
i,
@ -243,7 +243,7 @@ def main(_):
print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))
# evaluate test accuracy
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
test_acc = accuracy(params, shape_as_image(test_images, test_labels))
test_loss = loss(params, shape_as_image(test_images, test_labels))
print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(


@ -42,17 +42,17 @@ def gram(kernel, xs):
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
opt_init, opt_update = optimizers.momentum(step_size, mass)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)
@jit
def update(i, opt_state):
x = optimizers.get_params(opt_state)
x = get_params(opt_state)
return opt_update(i, grad(f)(x), opt_state)
opt_state = opt_init(x)
for i in xrange(num_steps):
opt_state = update(i, opt_state)
return optimizers.get_params(opt_state)
return get_params(opt_state)
def train(kernel, xs, ys, regularization=0.01):


@ -72,11 +72,11 @@ if __name__ == "__main__":
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
@jit
def update(i, opt_state, batch):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
_, init_params = init_random_params(rng, (-1, 28 * 28))
@ -90,7 +90,7 @@ if __name__ == "__main__":
opt_state = update(next(itercount), opt_state, next(batches))
epoch_time = time.time() - start_time
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))


@ -102,7 +102,7 @@ if __name__ == "__main__":
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
def binarize_batch(rng, i, images):
i = i % num_batches
@ -115,13 +115,13 @@ if __name__ == "__main__":
elbo_rng, data_rng = random.split(random.fold_in(rng, i))
batch = binarize_batch(data_rng, i, train_images)
loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
g = grad(loss)(optimizers.get_params(opt_state))
g = grad(loss)(get_params(opt_state))
return opt_update(i, g, opt_state)
return lax.fori_loop(0, num_batches, body_fun, opt_state)
@jit
def evaluate(opt_state, images):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
binarized_test = random.bernoulli(data_rng, images)
test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]


@ -117,16 +117,16 @@ if __name__ == "__main__":
onehot_labels = labels == np.arange(num_classes)
yield images, onehot_labels
opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
batches = synth_batches()
@jit
def update(i, opt_state, batch):
params = optimizers.get_params(opt_state)
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
opt_state = opt_init(init_params)
for i in xrange(num_steps):
opt_state = update(i, opt_state, next(batches))
trained_params = optimizers.get_params(opt_state)
trained_params = get_params(opt_state)


@ -24,12 +24,14 @@ from __future__ import print_function
from functools import partial
import time
import numpy as onp
import numpy.random as npr
from jax import jit, grad, pmap, replicate, unreplicate
from jax import jit, grad, pmap
from jax.config import config
from jax.scipy.special import logsumexp
from jax.lib import xla_bridge
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as np
from examples import datasets
@ -74,9 +76,8 @@ if __name__ == "__main__":
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
# For this manual SPMD example, we get the number of devices (e.g. GPUs or
# TPUs) that we're using, and use it to reshape data minibatches.
# TPU cores) that we're using, and use it to reshape data minibatches.
num_devices = xla_bridge.device_count()
def data_stream():
rng = npr.RandomState(0)
@ -106,20 +107,23 @@ if __name__ == "__main__":
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
# We replicate parameters out across devices. (Check the implementation of
# replicate; analogous to device_put, it's a simple wrapper around pmap.)
params = replicate(init_random_params(param_scale, layer_sizes))
# We replicate the parameters so that the constituent arrays have a leading
# dimension of size equal to the number of devices we're pmapping over.
init_params = init_random_params(param_scale, layer_sizes)
replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape)
replicated_params = tree_map(replicate_array, init_params)
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
params = spmd_update(params, next(batches))
replicated_params = spmd_update(replicated_params, next(batches))
epoch_time = time.time() - start_time
# We evaluate using the jitted `accuracy` function (not using pmap) by
# grabbing just one of the replicated parameter values.
train_acc = accuracy(unreplicate(params), (train_images, train_labels))
test_acc = accuracy(unreplicate(params), (test_images, test_labels))
params = tree_map(lambda x: x[0], replicated_params)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))


@ -78,12 +78,12 @@ def jit(fun, static_argnums=()):
provided they are hashable and have an equality operation defined. Static
arguments are included as part of a compilation cache key, which is why
hash and equality operators must be defined.
static_argnums: A tuple of ints. Specifies which positional arguments to
static_argnums: A tuple of ints specifying which positional arguments to
treat as static (compile-time constant). Operations that only depend on
static arguments will be constant-folded. Calling the jitted function with
different values for these constants will trigger recompilation. If the
jitted function is called with fewer positional arguments than indicated
by `static_argnums` then an error is raised.
by `static_argnums` then an error is raised. Defaults to ().
Returns:
A wrapped version of `fun`, set up for just-in-time compilation.
@ -101,6 +101,9 @@ def jit(fun, static_argnums=()):
[-0.54485154 0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813
-0.8574326 -0.7823267 0.7682731 0.59566754]
"""
return _jit(fun, static_argnums)
def _jit(fun, static_argnums, device_values=True):
@wraps(fun)
def f_jitted(*args, **kwargs):
if _jit_is_disabled or config.read('jax_disable_jit'):
@ -115,7 +118,7 @@ def jit(fun, static_argnums=()):
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
_check_args(args_flat)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = xla.xla_call(flat_fun, *args_flat)
out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
return tree_unflatten(out_tree(), out)
jitted_name = "jit({}, static_argnums={})"
@ -807,12 +810,7 @@ tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
device_put = jit(lambda x: x)
_device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x
device_get = partial(tree_map, _device_get_array)
_replicate_array = lambda x: onp.broadcast_to(x, (device_count(),) + onp.shape(x))
replicate = partial(tree_map, _replicate_array)
unreplicate = lambda x: tree_map(op.itemgetter(0), x)
device_get = _jit(lambda x: x, (), device_values=False)
def _argnums_partial(f, dyn_argnums, args):
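
As a usage note for the `static_argnums` behavior documented above: arguments at those positions are treated as compile-time constants, and calling with a different static value triggers recompilation. A short illustrative sketch; the function and argument names are made up:

```python
from functools import partial
import jax.numpy as np
from jax import jit

# `axis` must be a concrete Python int inside the trace, so we mark it static;
# each new axis value is baked into a separate compiled program.
@partial(jit, static_argnums=(1,))
def reduce_along(x, axis):
  return np.sum(x, axis=axis)

x = np.ones((3, 4))
print(reduce_along(x, 0))  # compiles for axis=0
print(reduce_along(x, 1))  # new static value -> recompiles
```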


@ -545,7 +545,7 @@ def apply_todos(todos, x):
return x
@lu.transformation_with_aux
def process_env_traces(primitive, level, *args):
def process_env_traces(primitive, level, params_tuple, *args):
ans = yield args, {}
todo = []
while isinstance(ans, Tracer) and ans.trace.level > level:
@ -553,25 +553,26 @@ def process_env_traces(primitive, level, *args):
sublevel = cur_sublevel()
trace = type(t)(t.master, sublevel)
ans = trace.full_raise(ans)
ans, cur_todo = ans.trace.post_process_call(primitive, ans)
ans, cur_todo = ans.trace.post_process_call(primitive, ans, dict(params_tuple))
todo.append(cur_todo)
yield ans, todo
def call_bind(primitive, f, *args, **kwargs):
def call_bind(primitive, f, *args, **params):
top_trace = find_top_trace(args)
level = trace_stack.next_level(True) if top_trace is None else top_trace.level
f, env_trace_todo = process_env_traces(f, primitive, level)
params_tuple = tuple(params.items())
f, env_trace_todo = process_env_traces(f, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
ans = primitive.impl(f, *args, **kwargs)
ans = primitive.impl(f, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
return apply_todos(env_trace_todo(), ans)
def call_impl(f, *args, **kwargs):
return f(*args, **kwargs)
def call_impl(f, *args, **params):
return f(*args, **params)
call_p = Primitive('call')


@ -29,35 +29,60 @@ import operator
import jax.numpy as np
from jax.core import pack
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import (tree_map, tree_mimomap, tree_structure,
register_pytree_node)
from jax import tree_util
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
map = safe_map
zip = safe_zip
OptimizerState = collections.namedtuple("OptimizerState",
["packed_state", "tree", "subtrees"])
register_pytree_node(OptimizerState,
lambda xs: ((xs.packed_state,), (xs.tree, xs.subtrees)),
lambda data, xs: OptimizerState(xs[0], data[0], data[1]))
def optimizer(opt_maker):
"""Decorator to make an optimizer map over tuple/list/dict containers."""
@functools.wraps(opt_maker)
def tree_opt_maker(*args, **kwargs):
init_fun, update_fun = opt_maker(*args, **kwargs)
init, update, get_params = opt_maker(*args, **kwargs)
@functools.wraps(init_fun)
def tree_init_fun(x0_tree):
return tree_mimomap(init_fun, x0_tree)
@functools.wraps(init)
def tree_init(x0_tree):
x0_flat, tree = tree_flatten(x0_tree)
initial_states = [init(x0) for x0 in x0_flat]
states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
packed_state = pack(map(pack, states_flat))
return OptimizerState(packed_state, tree, subtrees)
@functools.wraps(update_fun)
def tree_update_fun(i, grad_tree, state_trees):
return tree_mimomap(partial(update_fun, i), grad_tree, *state_trees)
@functools.wraps(update)
def tree_update(i, grad_tree, opt_state):
packed_state, tree, subtrees = opt_state
grad_flat, tree2 = tree_flatten(grad_tree)
assert tree == tree2
states = map(tree_unflatten, subtrees, packed_state)
new_states = map(partial(update, i), grad_flat, states)
new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
for subtree, subtree2 in zip(subtrees, subtrees2):
if subtree != subtree2:
msg = ("optimizer update function produced an output structure that "
"did not match its input structure: input {} and output {}.")
raise TypeError(msg.format(subtree, subtree2))
new_packed_state = pack(map(pack, new_states_flat))
return OptimizerState(new_packed_state, tree, subtrees)
return tree_init_fun, tree_update_fun
@functools.wraps(get_params)
def tree_get_params(opt_state):
packed_state, tree, subtrees = opt_state
states = map(tree_unflatten, subtrees, packed_state)
params = map(get_params, states)
return tree_unflatten(tree, params)
return tree_init, tree_update, tree_get_params
return tree_opt_maker
def iterate(state_trees):
"""Extract the current iterate from an optimizer state."""
return state_trees[0]
get_params = iterate
# optimizers
### optimizers
@optimizer
def sgd(step_size):
@ -68,14 +93,16 @@ def sgd(step_size):
that maps the iteration index to positive scalar.
Returns:
An (init_fun, update_fun) pair.
An (init, update, get_params) triple.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
return (x0,)
def update_fun(i, g, x):
return (x - step_size(i) * g,)
return init_fun, update_fun
def init(x0):
return x0
def update(i, g, x):
return x - step_size(i) * g
def get_params(x):
return x
return init, update, get_params
@optimizer
def momentum(step_size, mass):
@ -86,17 +113,21 @@ def momentum(step_size, mass):
that maps the iteration index to positive scalar.
Returns:
An (init_fun, update_fun) pair.
An (init, update, get_params) triple.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
def init(x0):
v0 = np.zeros_like(x0)
return x0, v0
def update_fun(i, g, x, velocity):
def update(i, g, state):
x, velocity = state
velocity = mass * velocity - (1. - mass) * g
x = x + step_size(i) * velocity
return x, velocity
return init_fun, update_fun
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
@ -107,17 +138,21 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
that maps the iteration index to positive scalar.
Returns:
An (init_fun, update_fun) pair.
An (init, update, get_params) triple.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
def init(x0):
avg_sq_grad = np.ones_like(x0)
return x0, avg_sq_grad
def update_fun(i, g, x, avg_sq_grad):
def update(i, g, state):
x, avg_sq_grad = state
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
x = x - step_size(i) * g / (np.sqrt(avg_sq_grad) + eps)
return x, avg_sq_grad
return init_fun, update_fun
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
@ -134,23 +169,28 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun) pair.
An (init, update, get_params) triple.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
def init(x0):
m0 = np.zeros_like(x0)
v0 = np.zeros_like(x0)
return x0, m0, v0
def update_fun(i, g, x, m, v):
def update(i, g, state):
x, m, v = state
m = (1 - b1) * g + b1 * m # First moment estimate.
v = (1 - b2) * (g ** 2) + b2 * v # Second moment estimate.
mhat = m / (1 - b1 ** (i + 1)) # Bias correction.
vhat = v / (1 - b2 ** (i + 1))
x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
return x, m, v
return init_fun, update_fun
def get_params(state):
x, m, v = state
return x
return init, update, get_params
# learning rate schedules
### learning rate schedules
def constant(step_size):
def schedule(i):
@ -183,10 +223,10 @@ def piecewise_constant(boundaries, values):
return values[np.sum(i > boundaries)]
return schedule
def make_schedule(scalar_or_schedule_fun):
if callable(scalar_or_schedule_fun):
return scalar_or_schedule_fun
elif np.ndim(scalar_or_schedule_fun) == 0:
return constant(scalar_or_schedule_fun)
def make_schedule(scalar_or_schedule):
if callable(scalar_or_schedule):
return scalar_or_schedule
elif np.ndim(scalar_or_schedule) == 0:
return constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule_fun))
raise TypeError(type(scalar_or_schedule))
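
The `@optimizer` decorator above expects the wrapped maker to return per-leaf `(init, update, get_params)` functions; it then flattens parameter pytrees, applies those functions leaf by leaf, and packs the per-leaf states into an `OptimizerState`. A hedged sketch of a custom optimizer written against that contract; sign-SGD is an illustrative choice, not part of the library:

```python
import jax.numpy as np
from jax import grad
from jax.experimental import optimizers

@optimizers.optimizer
def sign_sgd(step_size):
  # Per-leaf functions: the decorator handles arbitrary parameter pytrees.
  step_size = optimizers.make_schedule(step_size)
  def init(x0):
    return x0                        # the state is just the current iterate
  def update(i, g, x):
    return x - step_size(i) * np.sign(g)
  def get_params(x):
    return x
  return init, update, get_params

init_fun, update_fun, get_params = sign_sgd(0.1)
opt_state = init_fun({'w': np.ones(3)})              # nested params are fine
g = grad(lambda p: np.sum(p['w'] ** 2))({'w': np.ones(3)})
opt_state = update_fun(0, g, opt_state)
print(get_params(opt_state))
```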


@ -224,7 +224,7 @@ class JVPTrace(Trace):
primal_out, tangent_out = build_tree(out_tree_def(), result)
return JVPTracer(self, primal_out, tangent_out)
def post_process_call(self, _, out_tracer):
def post_process_call(self, call_primitive, out_tracer, params):
out_jtuple, tree_def = tree_to_jaxtuples((out_tracer.primal, out_tracer.tangent))
master = self.master
def todo(x):


@ -131,7 +131,7 @@ class BatchTrace(Trace):
val_out = call_primitive.bind(f, *vals, **params)
return BatchTracer(self, val_out, dim_out())
def post_process_call(self, _, out_tracer):
def post_process_call(self, call_primitive, out_tracer, params):
val, dim = out_tracer.val, out_tracer.batch_dim
master = self.master
def todo(x):


@ -145,7 +145,7 @@ class SerialPmapTrace(Trace):
val_out = call_primitive.bind(f, *vals, **params)
return SerialPmapTracer(self, name, val_out, axis_out())
def post_process_call(self, _, out_tracer):
def post_process_call(self, call_primitive, out_tracer, params):
name, val, axis = out_tracer.name, out_tracer.val, out_tracer.axis
master = self.master
def todo(x):


@ -111,7 +111,7 @@ class JaxprTrace(Trace):
eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,), False, params)
return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn)
def post_process_call(self, call_primitive, out_tracer):
def post_process_call(self, call_primitive, out_tracer, params):
# TODO(mattjj): post_process_map
jaxpr, consts, env = tracers_to_jaxpr([], out_tracer)
out_pv, out_pv_const = out_tracer.pval
@ -123,7 +123,7 @@ class JaxprTrace(Trace):
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, {})
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, params)
return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), eqn)
return out, todo


@ -200,6 +200,11 @@ def device_put_many(xs_and_devices):
# manipulate than simple Python builtins, we store the metadata required for
# forming the DeviceValue result in special ResultArray / ResultTuple classes.
# Every JaxType needs to map to an XLA type. However this function's design is
# based on the assumption that XLA types can be mapped uniquely back to a
# JaxType, i.e. that the mapping is bijective. That assumption could be relaxed,
but it would mean we need to do a bit more bookkeeping on the Python side to
# track abstract values of outputs.
def xla_shape_to_result_shape(xla_shape):
if xla_shape.is_tuple():
aval = aval_from_xla_shape(xla_shape)
@ -232,7 +237,8 @@ def pyval_result_handler(result_shape):
if t is ResultArray:
return lambda buf: buf.to_py()
elif t is ResultTuple:
handlers = list(map(pyval_result_handler, result_shape))
_, result_shapes = result_shape
handlers = list(map(pyval_result_handler, result_shapes))
return lambda buf: JaxTuple(h(b) for h, b in zip(handlers, buf.destructure()))
else:
raise TypeError(t)
@ -548,8 +554,9 @@ def xla_shape(x):
raise TypeError(type(x))
def xla_call_impl(fun, *args):
compiled_fun = xla_callable(fun, *map(abstractify, args))
def xla_call_impl(fun, *args, **params):
device_values = FLAGS.jax_device_values and params.pop('device_values')
compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
try:
return compiled_fun(*args)
except FloatingPointError:
@ -559,14 +566,17 @@ def xla_call_impl(fun, *args):
@lu.memoize
def xla_callable(fun, *abstract_args):
def xla_callable(fun, device_values, *abstract_args):
pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
with core.new_master(pe.JaxprTrace, True) as master:
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
assert not env # no subtraces here (though cond might eventually need them)
compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
del master, consts, jaxpr, env
handle_result = result_handler(result_shape)
if device_values:
handle_result = device_persistent_result_handler(result_shape)
else:
handle_result = pyval_result_handler(result_shape)
return partial(execute_compiled, compiled, pval, handle_result)
def execute_compiled(compiled, pval, handle_result, *args):
@ -576,7 +586,7 @@ def execute_compiled(compiled, pval, handle_result, *args):
return pe.merge_pvals(handle_result(out_buf), pval)
def xla_call_translation_rule(c, subc_a1, *a2):
def xla_call_translation_rule(c, subc_a1, *a2, **params):
subc, a1 = subc_a1
return c.Call(subc, a1 + a2)


@ -152,6 +152,7 @@ def check_vjp(f, f_vjp, args, atol=ATOL, rtol=RTOL, eps=EPS):
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
args = tuple(args)
if order > 1:
def f_vjp(*args):
out_primal_py, vjp_py = api.vjp(f, *args)


@ -79,71 +79,24 @@ def tree_multimap(f, tree, *rest):
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
"""
# equivalent to prefix_multimap(f, tree_structure(tree), tree, *rest)
node_type = node_types.get(type(tree))
if node_type:
children, node_spec = node_type.to_iterable(tree)
children, aux_data = node_type.to_iterable(tree)
all_children = [children]
for other_tree in rest:
other_node_type = node_types.get(type(other_tree))
# TODO(mattjj): enable this check
# if node_type != other_node_type:
# raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_spec:
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_spec))
if node_type != other_node_type:
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
other_children, other_aux_data = node_type.to_iterable(other_tree)
if other_aux_data != aux_data:
raise TypeError('Mismatch: {} != {}'.format(other_aux_data, aux_data))
all_children.append(other_children)
new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
return node_type.from_iterable(node_spec, new_children)
return node_type.from_iterable(aux_data, new_children)
else:
return f(tree, *rest)
def prefix_multimap(f, treedef, tree, *rest):
"""Like tree_multimap but only maps down through a tree prefix."""
if treedef is leaf:
return f(tree, *rest)
else:
node_type = node_types.get(type(tree))
if node_type != treedef.node_type:
raise TypeError('Mismatch: {} != {}'.format(treedef.node_type, node_type))
children, node_data = node_type.to_iterable(tree)
if node_data != treedef.node_data:
raise TypeError('Mismatch: {} != {}'.format(treedef.node_data, node_data))
all_children = [children]
for other_tree in rest:
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_data:
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_data))
all_children.append(other_children)
all_children = zip(*all_children)
new_children = [prefix_multimap(f, td, *xs)
for td, xs in zip(treedef.children, all_children)]
return node_type.from_iterable(node_data, new_children)
def tree_mimomap(f, tree, *rest):
"""Map a multi-input tuple-output over pytree args to form a tuple of pytrees.
Args:
f: function that takes `1 + len(rest)` arguments and returns a tuple, to be
applied at the corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
*rest: a tuple of pytrees, each with the same structure as `tree`.
Returns:
A tuple of pytrees with length given by the length of the output of `f` and
with each pytree element having the same structure as `tree`.
"""
flat, treedef = tree_flatten(tree)
rest_flat, treedefs = unzip2(map(tree_flatten, rest))
if not all(td == treedef for td in treedefs):
td = next(td for td in treedefs if td != treedef)
raise TypeError('Mismatch: {} != {}'.format(treedef, td))
out_flat = zip(*map(f, flat, *rest_flat))
return tuple(map(partial(tree_unflatten, treedef), out_flat))
def tree_reduce(f, tree):
flat, _ = tree_flatten(tree)
@ -270,6 +223,9 @@ class NodeType(object):
self.to_iterable = to_iterable
self.from_iterable = from_iterable
def __repr__(self):
return self.name
node_types = {}
def register_pytree_node(py_type, to_iterable, from_iterable):
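
For reference, `tree_multimap` as revised here requires every tree in `rest` to match the node types and auxiliary data of `tree`, and applies `f` across corresponding leaves. A small usage sketch with illustrative trees:

```python
import jax.numpy as np
from jax.tree_util import tree_multimap

params = {'w': np.ones(2), 'b': np.zeros(2)}
grads = {'w': np.full(2, 0.5), 'b': np.full(2, 0.1)}

# Leafwise SGD step; a structure mismatch between the two dicts would raise
# the TypeError('Mismatch: ...') produced by the checks above.
new_params = tree_multimap(lambda p, g: p - 0.1 * g, params, grads)
print(new_params)
```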


@ -188,19 +188,19 @@
"metadata": {},
"outputs": [],
"source": [
"opt_init, opt_update = optimizers.adam(step_size=1e-2)\n",
"opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)\n",
"opt_state = opt_init(net_params)\n",
"\n",
"# Define a compiled update step\n",
"@jit\n",
"def step(i, opt_state, x1, y1):\n",
" p = optimizers.get_params(opt_state)\n",
" p = get_params(opt_state)\n",
" g = grad(loss)(p, x1, y1)\n",
" return opt_update(i, g, opt_state)\n",
"\n",
"for i in range(100):\n",
" opt_state = step(i, opt_state, xrange_inputs, targets)\n",
"net_params = optimizers.get_params(opt_state)"
"net_params = get_params(opt_state)"
]
},
{


@ -23,12 +23,16 @@ from absl.testing import absltest
import jax.numpy as np
import jax.test_util as jtu
from jax import jit, grad
from jax import core, tree_util
from jax.experimental import optimizers
from jax.lib import xla_bridge as xla
from jax.interpreters import xla
from jax.config import config
config.parse_flags_with_absl()
dummy_data = None
class OptimizerTests(jtu.JaxTestCase):
def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs):
@ -36,29 +40,33 @@ class OptimizerTests(jtu.JaxTestCase):
self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs)
def _CheckFuns(self, optimizer, loss, x0, *args):
init_fun, update_fun = optimizer(*args)
init_fun, update_fun, get_params_fun = optimizer(*args)
opt_state = init_fun(x0)
update_fun(0, grad(loss)(x0, None), opt_state) # doesn't crash
self.assertAllClose(x0, get_params_fun(opt_state), check_dtypes=True)
opt_state2 = update_fun(0, grad(loss)(x0, dummy_data), opt_state) # doesn't crash
self.assertEqual(tree_util.tree_structure(opt_state),
tree_util.tree_structure(opt_state2))
@jtu.skip_on_devices('gpu')
def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):
return # TODO(mattjj): bring back fax!
# num_repl = xla.get_replica_count()
# infeeder = fax.make_infeed_from_sequence(
# [np.ones(1, dtype='float32')] * num_steps * num_repl,
# with_pyvals=True)
init_fun, update_fun, get_params_fun = optimizer(*args)
# def op(infeed, x0):
# opt_init, opt_update = optimizer(*args, **kwargs)
# return optimizers.run_optimizer(loss, infeed, opt_update, opt_init(x0))
# cop = jit(op)
opt_state = init_fun(x0)
for i in range(num_steps):
x = get_params_fun(opt_state)
g = grad(loss)(x, dummy_data)
opt_state = update_fun(i, g, opt_state)
xstar = get_params_fun(opt_state)
self.assertLess(loss(xstar, dummy_data), 1e-2)
# a1, _ = op(infeeder(), x0)
# a2, _ = cop(infeeder(), x0)
# assert loss(a1, None) < 1e-3
# assert loss(a2, None) < 1e-3
# self.assertAllClose(a1, a2, check_dtypes=False)
update_fun_jitted = jit(update_fun)
opt_state = init_fun(x0)
for i in range(num_steps):
x = get_params_fun(opt_state)
g = grad(loss)(x, dummy_data)
opt_state = update_fun_jitted(i, g, opt_state)
xstar = get_params_fun(opt_state)
self.assertLess(loss(xstar, dummy_data), 1e-2)
def testSgdScalar(self):
def loss(x, _): return x**2
@ -91,6 +99,14 @@ class OptimizerTests(jtu.JaxTestCase):
mass = 0.
self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)
def testMomentumDict(self):
def loss(dct, _): return np.dot(dct['x'], dct['x'])
x0 = {'x': np.ones(2)}
num_iters = 100
step_size = 0.1
mass = 0.
self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)
def testRmspropVector(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
@ -118,57 +134,76 @@ class OptimizerTests(jtu.JaxTestCase):
def testSgdVectorExponentialDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
def testSgdVectorInverseTimeDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
def testAdamVectorInverseTimeDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_schedule)
self._CheckFuns(optimizers.adam, loss, x0, step_schedule)
def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
mass = 0.9
self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_sched, mass)
self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)
def testRmspropVectorPiecewiseConstantSchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_schedule)
self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)
def testTracedStepSize(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_size = 0.1
init_fun, _ = optimizers.sgd(step_size)
init_fun, _, _ = optimizers.sgd(step_size)
opt_state = init_fun(x0)
@jit
def update(opt_state, step_size):
_, update_fun = optimizers.sgd(step_size)
x = optimizers.get_params(opt_state)
_, update_fun, get_params = optimizers.sgd(step_size)
x = get_params(opt_state)
g = grad(loss)(x, None)
return update_fun(0, g, opt_state)
update(opt_state, 0.9) # doesn't crash
def testDeviceTupleState(self):
init_fun, update_fun, _ = optimizers.sgd(0.1)
opt_state = init_fun(np.zeros(3))
self.assertIsInstance(opt_state, optimizers.OptimizerState)
self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
opt_state = jit(update_fun)(0, np.zeros(3), opt_state)
self.assertIsInstance(opt_state, optimizers.OptimizerState)
self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)
def testUpdateFunStructureMismatchErrorMessage(self):
@optimizers.optimizer
def opt_maker():
def init_fun(x0):
return {'x': x0}
def update_fun(i, g, opt_state):
x = opt_state['x']
return {'x': x - 0.1 * g, 'v': g} # bug!
def get_params(opt_state):
return opt_state['x']
return init_fun, update_fun, get_params
init_fun, update_fun, get_params = opt_maker()
opt_state = init_fun(np.zeros(3))
self.assertRaises(TypeError, lambda: update_fun(opt_state))
if __name__ == '__main__':
absltest.main()