sync updates

This commit is contained in:
Matthew Johnson 2018-11-19 07:43:23 -08:00
parent ae641fdaec
commit 46c6a9170f
19 changed files with 552 additions and 113 deletions

View File

@ -20,9 +20,11 @@ from absl import app
import IPython
import numpy as onp
import jax
import jax.numpy as np
from jax import lax
from jax import numpy as np
from jax import jit, grad, vmap
from jax import random
from jax import jit, grad, vmap, jacfwd, jacrev, hessian
def main(unused_argv):

View File

@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""A basic MNIST example using JAX together with the mini-libraries stax, for
neural network building, and minmax, for first-order stochastic optimization.
"""
from __future__ import absolute_import
@ -22,26 +21,19 @@ from __future__ import division
from __future__ import print_function
import time
import itertools
from absl import app
import numpy.random as npr
from jax.api import jit, grad
from jax.examples import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.examples import datasets
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, Softmax
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [(scale * rng.randn(m, n), scale * rng.randn(n))
for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
def predict(params, inputs):
for w, b in params:
outputs = np.dot(inputs, w) + b
inputs = np.tanh(outputs)
return outputs - logsumexp(outputs, axis=1, keepdims=True)
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
@ -53,13 +45,16 @@ def accuracy(params, batch):
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
init_random_params, predict = stax.serial(
Dense(1024), Relu,
Dense(1024), Relu,
Dense(10), Softmax)
def main(unused_argv):
layer_sizes = [784, 1024, 1024, 10] # TODO(mattjj): revise to standard arch
param_scale = 0.1
step_size = 0.001
num_epochs = 10
batch_size = 32
momentum_mass = 0.9
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
@ -75,19 +70,24 @@ def main(unused_argv):
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
@jit
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
opt_init, opt_update = minmax.momentum(step_size, mass=momentum_mass)
@jit
def update(i, opt_state, batch):
params = minmax.get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
_, init_params = init_random_params((-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()
params = init_random_params(param_scale, layer_sizes)
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
params = update(params, next(batches))
opt_state = update(next(itercount), opt_state, next(batches))
epoch_time = time.time() - start_time
params = minmax.get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

View File

@ -0,0 +1,99 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import app
import numpy.random as npr
from jax.api import jit, grad
from jax.examples import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [(scale * rng.randn(m, n), scale * rng.randn(n))
for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
def predict(params, inputs):
for w, b in params:
outputs = np.dot(inputs, w) + b
inputs = np.tanh(outputs)
return outputs - logsumexp(outputs, axis=1, keepdims=True)
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(preds * targets)
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
def main(unused_argv):
layer_sizes = [784, 1024, 1024, 10] # TODO(mattjj): revise to standard arch
param_scale = 0.1
step_size = 0.001
num_epochs = 10
batch_size = 32
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
def data_stream():
rng = npr.RandomState(0)
while True:
perm = rng.permutation(num_train)
for i in range(num_batches):
batch_idx = perm[i * batch_size:(i + 1) * batch_size]
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
@jit
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
params = init_random_params(param_scale, layer_sizes)
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
params = update(params, next(batches))
epoch_time = time.time() - start_time
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)
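In predict above, subtracting logsumexp normalizes the final activations into log-probabilities, which makes loss a (scaled) negative log-likelihood of the one-hot targets. A quick standalone check of that normalization using the SciPy reference (values are illustrative):
    import numpy as onp
    from scipy.special import logsumexp  # reference implementation for the check

    outputs = onp.array([[1.0, 2.0, 3.0]])
    log_probs = outputs - logsumexp(outputs, axis=1, keepdims=True)
    assert onp.allclose(onp.exp(log_probs).sum(axis=1), 1.0)  # each row sums to one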

View File

@ -28,13 +28,12 @@ import time
from absl import app
import matplotlib.pyplot as plt
from jax import lax, random
from jax.api import jit, grad
import jax.numpy as np
from jax import jit, grad, lax, random
from jax.examples import datasets
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import jax.numpy as np
def gaussian_kl(mu, sigmasq):

examples/resnet50.py Normal file
View File

@ -0,0 +1,134 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A mock-up showing a ResNet50 network with training on synthetic data.
This file uses the stax neural network definition library and the minmax
optimization library.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import numpy.random as npr
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
FanOut, Flatten, GeneralConv, Identity,
MaxPool, Relu, Softmax)
# ResNet blocks compose other layers
def ConvBlock(kernel_size, filters, strides=(2, 2)):
ks = kernel_size
filters1, filters2, filters3 = filters
Main = stax.serial(
Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
Conv(filters3, (1, 1)), BatchNorm())
Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
def IdentityBlock(kernel_size, filters):
ks = kernel_size
filters1, filters2 = filters
def make_main(input_shape):
# the number of output channels depends on the number of input channels
return stax.serial(
Conv(filters1, (1, 1)), BatchNorm(), Relu,
Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
Conv(input_shape[3], (1, 1)), BatchNorm())
Main = stax.shape_dependent(make_main)
return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
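Both block constructors express the same residual computation: FanOut(2) duplicates the input, parallel runs the two branches, FanInSum adds their outputs, and Relu follows. Written with plain callables instead of stax combinators (purely schematic):
    def residual_block(main, shortcut, relu):
      # schematic only: mirrors FanOut(2) -> parallel(main, shortcut) -> FanInSum -> Relu
      def apply(x):
        return relu(main(x) + shortcut(x))
      return apply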
# ResNet architectures compose layers and ResNet blocks
def ResNet50(num_classes):
return stax.serial(
GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
ConvBlock(3, [64, 64, 256], strides=(1, 1)),
IdentityBlock(3, [64, 64]),
IdentityBlock(3, [64, 64]),
ConvBlock(3, [128, 128, 512]),
IdentityBlock(3, [128, 128]),
IdentityBlock(3, [128, 128]),
IdentityBlock(3, [128, 128]),
ConvBlock(3, [256, 256, 1024]),
IdentityBlock(3, [256, 256]),
IdentityBlock(3, [256, 256]),
IdentityBlock(3, [256, 256]),
IdentityBlock(3, [256, 256]),
IdentityBlock(3, [256, 256]),
ConvBlock(3, [512, 512, 2048]),
IdentityBlock(3, [512, 512]),
IdentityBlock(3, [512, 512]),
AvgPool((7, 7)), Flatten, Dense(num_classes), Softmax)
def main(argv):
del argv # Unused.
batch_size = 8
num_classes = 1001
input_shape = (224, 224, 3, batch_size)
step_size = 0.1
num_steps = 10
init_fun, predict_fun = ResNet50(num_classes)
_, init_params = init_fun(input_shape)
def loss(params, batch):
inputs, targets = batch
logits = predict_fun(params, inputs)
return -np.sum(logits * targets)
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=-1)
predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
return np.mean(predicted_class == target_class)
def synth_batches():
rng = npr.RandomState(0)
while True:
images = rng.rand(*input_shape).astype('float32')
labels = rng.randint(num_classes, size=(batch_size, 1))
onehot_labels = labels == np.arange(num_classes)
yield images, onehot_labels
opt_init, opt_update = minmax.momentum(step_size, mass=0.9)
batches = synth_batches()
@jit
def update(i, opt_state, batch):
params = minmax.get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
opt_state = opt_init(init_params)
for i in range(num_steps):
opt_state = update(i, opt_state, next(batches))
trained_params = minmax.get_params(opt_state)
if __name__ == '__main__':
app.run(main)

View File

@ -428,9 +428,9 @@ pytype_aval_mappings = {}
# ------------------- Products -------------------
class JaxTuple(tuple):
def __new__(self, xs):
assert all(map(valid_jaxtype, xs)), xs
return tuple.__new__(JaxTuple, xs)
def __new__(cls, xs):
assert skip_checks or all(map(valid_jaxtype, xs)), xs
return tuple.__new__(cls, xs)
def __repr__(self):
if self is unit:

View File

@ -23,21 +23,28 @@ from __future__ import division
from __future__ import print_function
import operator
import functools
import jax.numpy as np
from jax.core import pack
from jax.tree_util import tree_map, tree_multimap
import jax.numpy as np
def optimizer(opt_maker):
"""Decorator to make an optimizer map over tuple/list/dict containers."""
@functools.wraps(opt_maker)
def tree_opt_maker(*args, **kwargs):
init_fun, update_fun = opt_maker(*args, **kwargs)
@functools.wraps(init_fun)
def fmapped_init_fun(x0_tree):
return tree_map(lambda x0: pack(init_fun(x0)), x0_tree)
@functools.wraps(update_fun)
def fmapped_update_fun(i, grad_tree, state_tree):
update = lambda g, state: pack(update_fun(i, g, *state))
return tree_multimap(update, grad_tree, state_tree)
return fmapped_init_fun, fmapped_update_fun
return tree_opt_maker
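What the decorator buys: the per-array init_fun/update_fun pairs defined below work directly on nested containers of parameters, such as the list of (w, b) pairs used in the MNIST example. A short sketch (parameter shapes are illustrative):
    import jax.numpy as np
    from jax.experimental import minmax

    params = [(np.zeros((3, 2)), np.zeros(2)),   # layer 1: (w, b)
              (np.zeros((2, 1)), np.zeros(1))]   # layer 2: (w, b)
    opt_init, opt_update = minmax.sgd(0.01)
    opt_state = opt_init(params)   # optimizer state mirrors the container structure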
@ -46,42 +53,86 @@ def iterate(state_tree):
return tree_map(lambda state: tuple(state)[0], state_tree)
get_params = iterate
# optimizers
@optimizer
def sgd(step_size):
"""Init and update step functions for stochastic gradient descent."""
"""Construct init and update step functions for stochastic gradient descent.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
Returns:
An (init_fun, update_fun) pair.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
return (x0,)
def update_fun(i, g, x):
return (x - step_size * g,)
return (x - step_size(i) * g,)
return init_fun, update_fun
@optimizer
def momentum(step_size, mass):
"""Init and update step functions for SGD with Nesterov momentum."""
"""Construct init and update step functions for SGD with Nesterov momentum.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
An (init_fun, update_fun) pair.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
v0 = np.zeros_like(x0)
return x0, v0
def update_fun(i, g, x, velocity):
velocity = mass * velocity - (1. - mass) * g
x = x + step_size * velocity
x = x + step_size(i) * velocity
return x, velocity
return init_fun, update_fun
@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
"""Init and update step functions for RMSProp."""
"""Construct init and update step functions for RMSProp.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
gamma: optional, a positive scalar for the decay rate of the squared-gradient moving average (default 0.9).
eps: optional, a positive scalar added to the denominator for numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun) pair.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
avg_sq_grad = np.ones_like(x0)
return x0, avg_sq_grad
def update_fun(i, g, x, avg_sq_grad):
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
x = x - step_size * g / (np.sqrt(avg_sq_grad) + eps)
x = x - step_size(i) * g / (np.sqrt(avg_sq_grad) + eps)
return x, avg_sq_grad
return init_fun, update_fun
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""Init and update step functions for Adam."""
"""Construct init and update step functions for Adam.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.999).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun) pair.
"""
step_size = make_schedule(step_size)
def init_fun(x0):
m0 = np.zeros_like(x0)
v0 = np.zeros_like(x0)
@ -91,26 +142,47 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
v = (1 - b2) * (g ** 2) + b2 * v # Second moment estimate.
mhat = m / (1 - b1 ** (i + 1)) # Bias correction.
vhat = v / (1 - b2 ** (i + 1))
x = x - step_size*mhat / (np.sqrt(vhat) + eps)
x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
return x, m, v
return init_fun, update_fun
def run_optimizer(loss, infeed, update_fun, state):
"""A convenience function for running optimizers with iterated map-reduce.
Args:
loss: a scalar-valued loss function taking two arguments, the current iterate
and a data value.
infeed: an infeed instance supplying the data stream.
update_fun: a function that has signature update_fun(i, grad, state) where
i is the integer iteration count, grad is the gradient of the loss at the
current iterate, and state is the current optimizer state.
state: the initial optimizer state.
Returns:
A pair (x, state) where x is the final iterate and state represents the final
optimizer state.
"""
map_fun = lambda _, state, batch: grad(loss)(iterate(state), batch)
state = fax.iterated_map_reduce(state, map_fun, update_fun, infeed)
return iterate(state), state
# learning rate schedules
def constant(step_size):
def schedule(i):
return step_size
return schedule
def exponential_decay(step_size, decay_steps, decay_rate):
def schedule(i):
return step_size * decay_rate ** (i / decay_steps)
return schedule
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
if staircase:
def schedule(i):
return step_size / (1 + decay_rate * np.floor(i / decay_steps))
else:
def schedule(i):
return step_size / (1 + decay_rate * i / decay_steps)
return schedule
def piecewise_constant(boundaries, values):
boundaries = np.array(boundaries)
values = np.array(values)
if not boundaries.ndim == values.ndim == 1:
raise ValueError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise ValueError("boundaries length must be one longer than values length")
def schedule(i):
return values[np.sum(i > boundaries)]
return schedule
def make_schedule(constant_scalar_or_schedule_fun):
if np.isscalar(constant_scalar_or_schedule_fun):
return constant(constant_scalar_or_schedule_fun)
elif callable(constant_scalar_or_schedule_fun):
return constant_scalar_or_schedule_fun
else:
raise TypeError(type(constant_scalar_or_schedule_fun))
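make_schedule is what lets every optimizer accept either a plain scalar or one of the schedule constructors above; a usage sketch mirroring the optimizer tests in this commit (the specific constants are illustrative):
    from jax.experimental import minmax

    # a decayed step size plugged into momentum
    step_schedule = minmax.exponential_decay(0.1, decay_steps=1000, decay_rate=0.5)
    opt_init, opt_update = minmax.momentum(step_schedule, mass=0.9)

    # a plain scalar still works: make_schedule wraps it in a constant schedule
    opt_init, opt_update = minmax.sgd(0.001)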

View File

@ -173,12 +173,12 @@ class JaxprTracerTuple(tuple): pass
Destructuring = namedtuple('Destructuring', ['i', 'eqn', 'key'])
class PartialVal(tuple):
def __init__(self, xs):
def __new__(cls, xs):
assert core.skip_checks or (
isinstance(xs[0], valid_pv_types)
and isinstance(xs[1], core.Tracer) or core.valid_jaxtype(xs[1])
), xs
super(PartialVal, self).__init__(xs)
return tuple.__new__(cls, xs)
valid_pv_types = (AbstractValue, JaxprTracerTuple, type(None))
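This change, like the JaxTuple change above, moves construction-time logic from __init__ to __new__, the required pattern for tuple subclasses because a tuple's contents are fixed at creation time. A minimal standalone illustration (Pair is a made-up name):
    class Pair(tuple):
      def __new__(cls, xs):
        # tuple is immutable, so the elements must be supplied here;
        # by the time __init__ runs it is too late to populate them
        return tuple.__new__(cls, xs)

    p = Pair((1, 2))
    assert isinstance(p, Pair) and p == (1, 2)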

View File

@ -1566,9 +1566,30 @@ def slice_transpose_rule(t, start_indices, limit_indices, strides,
assert result.shape == operand_shape
return [result]
def slice_batching_rule(batched_args, batch_dims, start_indices, limit_indices,
strides, **unused_kwargs):
operand, = batched_args
bdim, = batch_dims
new_start_indices = list(start_indices)
new_start_indices.insert(bdim, 0)
new_limit_indices = list(limit_indices)
new_limit_indices.insert(bdim, operand.shape[bdim])
if strides is None:
new_strides = None
else:
new_strides = list(strides)
new_strides.insert(bdim, 1)
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
return out, bdim
slice_p = standard_primitive(slice_shape_rule, _input_dtype, 'slice',
slice_translation_rule)
ad.deflinear(slice_p, slice_transpose_rule)
batching.primitive_batchers[slice_p] = slice_batching_rule
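A worked instance of the new batching rule (shapes chosen for illustration): for a batched operand of shape (5, 10) with bdim=0, start_indices=(2,) and limit_indices=(4,), the rule inserts the batch dimension to slice with start (0, 2) and limit (5, 4), i.e. ordinary x[:, 2:4] indexing:
    import numpy as onp

    x = onp.arange(50.).reshape(5, 10)
    batched_slice = x[0:5, 2:4]       # start (0, 2), limit (5, 4), unit strides
    assert batched_slice.shape == (5, 2)
    assert (batched_slice == x[:, 2:4]).all()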
def dynamic_slice_shape_rule(operand, start_indices, slice_sizes,

View File

@ -171,7 +171,7 @@ def _constant_like(x, const):
def _wraps(fun):
"""Like functools.wraps but works with numpy.ufuncs."""
docstr = """
LAX-backed implementation of {fun}. Corresponding Numpy docstring below.
LAX-backed implementation of {fun}. Original docstring below.
{np_doc}
""".format(fun=fun.__name__, np_doc=fun.__doc__)

View File

@ -309,7 +309,7 @@ def shuffle(key, x, axis=0):
for _ in range(num_rounds):
key, subkey = split(key)
sort_keys = _random_bits(subkey, 32, x.shape)
_, x = lax.sort_keyval(sort_keys, x, axis)
_, x = lax.sort_key_val(sort_keys, x, axis)
return x

View File

@ -20,22 +20,7 @@ import numpy as onp
import scipy.misc as osp_misc
from .. import lax
def _wraps(fun):
"""Like functools.wraps but works with numpy.ufuncs."""
docstr = """
LAX-backed implementation of {fun}. Corresponding Scipy docstring below.
{np_doc}
""".format(fun=fun.__name__, np_doc=fun.__doc__)
def wrap(op):
try:
op.__name__ = fun.__name__
op.__doc__ = docstr
finally:
return op
return wrap
from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like
@_wraps(osp_misc.logsumexp)
@ -50,19 +35,3 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
_constant_like(a, 0), lax.add, dims)), amax)
return dimadd(out) if keepdims else out
# TODO(mattjj): this is duplicated from lax_numpy.py
def _reduction_dims(a, axis):
if axis is None:
return onp.arange(onp.ndim(a))
elif isinstance(axis, (onp.ndarray, tuple, list)):
return onp.mod(onp.asarray(axis), onp.ndim(a))
elif isinstance(axis, int):
return onp.mod([axis], onp.ndim(a))
else:
raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
def _constant_like(x, const):
return onp.array(const, dtype=lax._dtype(x))

View File

@ -19,22 +19,7 @@ from __future__ import print_function
import scipy.special as osp_special
from .. import lax
def _wraps(fun):
"""Like functools.wraps but works with numpy.ufuncs."""
docstr = """
LAX-backed implementation of {fun}. Corresponding Scipy docstring below.
{np_doc}
""".format(fun=fun.__name__, np_doc=fun.__doc__)
def wrap(op):
try:
op.__name__ = fun.__name__
op.__doc__ = docstr
finally:
return op
return wrap
from ..numpy.lax_numpy import _wraps
gammaln = _wraps(osp_special.gammaln)(lax.lgamma)

View File

@ -0,0 +1,16 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import norm

jax/scipy/stats/norm.py Normal file
View File

@ -0,0 +1,33 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as onp
import scipy.stats as osp_stats
from ... import lax
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
@_wraps(osp_stats.norm.logpdf)
def logpdf(x, loc=0, scale=1):
x, loc, scale = _promote_args_like(osp_stats.norm.logpdf, x, loc, scale)
two = _constant_like(x, 2)
scale_sqrd = lax.pow(scale, two)
log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * onp.pi), scale_sqrd))
quadratic = lax.div(lax.pow(lax.sub(x, loc), two), scale_sqrd)
return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
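The lax expression above implements the usual closed form logpdf(x; loc, scale) = -((x - loc)^2 / scale^2 + log(2*pi*scale^2)) / 2; a quick numerical check against the SciPy reference (values are illustrative):
    import numpy as onp
    import scipy.stats

    x, loc, scale = 1.3, 0.5, 2.0
    ref = -0.5 * ((x - loc)**2 / scale**2 + onp.log(2 * onp.pi * scale**2))
    assert onp.isclose(ref, scipy.stats.norm.logpdf(x, loc=loc, scale=scale))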

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax.abstract_arrays import ShapedArray
from jax import lax
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
from jax.api import vmap
from jax.core import unit
@ -135,6 +136,24 @@ class BatchingTest(jtu.JaxTestCase):
check_dtypes=False)
self.assertEqual(len(side), 1)
def testSliceLax(self):
fun = lambda x: lax.slice(x, (2,), (4,))
R = onp.random.RandomState(0).randn
x = R(5, 10)
ans = vmap(fun, x)
expected_ans = x[:, 2:4]
self.assertAllClose(ans, expected_ans, check_dtypes=False)
def testSliceNumpy(self):
fun = lambda x: x[:, 2]
R = onp.random.RandomState(0).randn
x = R(10, 5, 3, 7)
ans = vmap(fun, x)
expected_ans = x[:, :, 2]
self.assertAllClose(ans, expected_ans, check_dtypes=False)
if __name__ == '__main__':
absltest.main()

View File

@ -27,15 +27,17 @@ from absl.testing import parameterized
import numpy as onp
import scipy.misc as osp_misc
import scipy.special as osp_special
import scipy.stats as osp_stats
from jax import api
from jax import test_util as jtu
from jax.scipy import misc as lsp_misc
from jax.scipy import special as lsp_special
from jax.scipy import stats as lsp_stats
FLAGS = flags.FLAGS
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
@ -79,6 +81,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for axis in range(-len(shape), len(shape))
for keepdims in [False, True])
def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
# TODO(mattjj): test autodiff
def scipy_fun(array_to_reduce):
return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
@ -101,6 +104,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, modes):
# TODO(mattjj): unskip this test combination when real() on tpu is improved
# TODO(mattjj): test autodiff
if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
and not shapes[0]):
return absltest.unittest.skip("real() on scalar not supported on tpu")
@ -111,7 +115,40 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
check_dtypes=False)
self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)
# TODO(mattjj): add test for lsp_stats.norm_logpdf
@parameterized.named_parameters(
{"testcase_name": jtu.format_test_name_suffix(
"", shapes, dtypes),
"rng": rng, "shapes": shapes, "dtypes": dtypes}
for shapes in CombosWithReplacement(all_shapes, 3)
for dtypes in CombosWithReplacement(default_dtypes, 3)
for rng in [jtu.rand_default()])
def testNormLogPdfThreeArgs(self, rng, shapes, dtypes):
# TODO(mattjj): test autodiff
scipy_fun = osp_stats.norm.logpdf
lax_fun = lsp_stats.norm.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
scale = 0.5 + onp.abs(scale)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(
{"testcase_name": jtu.format_test_name_suffix(
"", shapes, dtypes),
"rng": rng, "shapes": shapes, "dtypes": dtypes}
for shapes in CombosWithReplacement(all_shapes, 2)
for dtypes in CombosWithReplacement(default_dtypes, 2)
for rng in [jtu.rand_default()])
def testNormLogPdfTwoArgs(self, rng, shapes, dtypes):
# TODO(mattjj): test autodiff
scale = 0.5
scipy_fun = functools.partial(osp_stats.norm.logpdf, scale=scale)
lax_fun = functools.partial(lsp_stats.norm.logpdf, scale=scale)
def args_maker():
return list(map(rng, shapes, dtypes))
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
if __name__ == "__main__":

View File

@ -114,5 +114,41 @@ class OptimizerTests(jtu.JaxTestCase):
partial_loss = functools.partial(loss, y)
self._CheckRun(minmax.sgd, partial_loss, x0, num_iters, step_size)
def testSgdVectorExponentialDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = minmax.exponential_decay(0.1, 3, 2.)
self._CheckOptimizer(minmax.sgd, loss, x0, num_iters, step_schedule)
def testSgdVectorInverseTimeDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = minmax.inverse_time_decay(0.1, 3, 2.)
self._CheckOptimizer(minmax.sgd, loss, x0, num_iters, step_schedule)
def testAdamVectorInverseTimeDecaySchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = minmax.inverse_time_decay(0.1, 3, 2.)
self._CheckOptimizer(minmax.adam, loss, x0, num_iters, step_schedule)
def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_sched = minmax.inverse_time_decay(0.1, 3, 2., staircase=True)
mass = 0.9
self._CheckOptimizer(minmax.momentum, loss, x0, num_iters, step_sched, mass)
def testRmspropVectorPiecewiseConstantSchedule(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_schedule = minmax.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
self._CheckOptimizer(minmax.rmsprop, loss, x0, num_iters, step_schedule)
if __name__ == '__main__':
absltest.main()

View File

@ -127,6 +127,23 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
@parameterized.named_parameters(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64])
def testShuffle(self, dtype):
key = random.PRNGKey(0)
x = onp.arange(100).astype(dtype)
rand = lambda key: random.shuffle(key, x)
crand = api.jit(rand)
perm1 = rand(key)
perm2 = crand(key)
self.assertTrue(onp.all(perm1 == perm2))
self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertTrue(onp.all(onp.sort(perm1) == x))
if __name__ == "__main__":
absltest.main()