Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
sync updates

Commit: 46c6a9170f
Parent: ae641fdaec
@@ -20,9 +20,11 @@ from absl import app
import IPython
import numpy as onp

import jax
import jax.numpy as np
from jax import lax
from jax import numpy as np
from jax import jit, grad, vmap
from jax import random
from jax import jit, grad, vmap, jacfwd, jacrev, hessian


def main(unused_argv):
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A basic MNIST example using Numpy and JAX.

The primary aim here is simplicity and minimal dependencies.
"""A basic MNIST example using JAX together with the mini-libraries stax, for
neural network building, and minmax, for first-order stochastic optimization.
"""

from __future__ import absolute_import
@@ -22,26 +21,19 @@ from __future__ import division
from __future__ import print_function

import time
import itertools

from absl import app
import numpy.random as npr

from jax.api import jit, grad
from jax.examples import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.examples import datasets
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, Softmax


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
  for w, b in params:
    outputs = np.dot(inputs, w) + b
    inputs = np.tanh(outputs)
  return outputs - logsumexp(outputs, axis=1, keepdims=True)

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
@@ -53,13 +45,16 @@ def accuracy(params, batch):
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), Softmax)

def main(unused_argv):
  layer_sizes = [784, 1024, 1024, 10]  # TODO(mattjj): revise to standard arch
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 32
  momentum_mass = 0.9

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
@@ -75,19 +70,24 @@ def main(unused_argv):
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  @jit
  def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
  opt_init, opt_update = minmax.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = minmax.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params((-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
      opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    params = minmax.get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
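The hunk above swaps the hand-rolled gradient update for the experimental stax and minmax mini-libraries added in this commit. As a rough, self-contained sketch of how those pieces fit together (the layer sizes, step size, and fake batch below are made up for illustration, and the API is the experimental one shown in this diff, not a documented interface):

# Sketch, not part of the commit.
import numpy as onp
import numpy.random as npr
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax, stax
from jax.experimental.stax import Dense, Relu, Softmax

# A small classifier; stax.serial returns an (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(Dense(128), Relu, Dense(10), Softmax)
_, params = init_fun((-1, 784))

def loss(params, batch):
  inputs, targets = batch
  return -np.mean(np.log(apply_fun(params, inputs)) * targets)  # stand-in cross entropy

# minmax optimizers expose (opt_init, opt_update); get_params unpacks the state.
opt_init, opt_update = minmax.momentum(0.001, mass=0.9)
opt_state = opt_init(params)

@jit
def step(i, opt_state, batch):
  params = minmax.get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

fake_batch = (npr.randn(32, 784), onp.eye(10)[npr.randint(0, 10, size=32)])
for i in range(3):
  opt_state = step(i, opt_state, fake_batch)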
examples/mnist_classifier_fromscratch.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A basic MNIST example using Numpy and JAX.

The primary aim here is simplicity and minimal dependencies.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl import app
import numpy.random as npr

from jax.api import jit, grad
from jax.examples import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
  for w, b in params:
    outputs = np.dot(inputs, w) + b
    inputs = np.tanh(outputs)
  return outputs - logsumexp(outputs, axis=1, keepdims=True)

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(preds * targets)

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)


def main(unused_argv):
  layer_sizes = [784, 1024, 1024, 10]  # TODO(mattjj): revise to standard arch
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 32

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  @jit
  def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))


if __name__ == "__main__":
  app.run(main)
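For readers skimming the new example, its core is the manual SGD update over a list of (w, b) pairs. A tiny self-contained sketch of that update on toy shapes (the layer sizes, stand-in squared-error loss, and fake batch are invented for illustration, not taken from the file):

# Sketch, not part of the commit.
import numpy.random as npr
import jax.numpy as np
from jax import grad

def predict(params, inputs):
  for w, b in params:
    inputs = np.tanh(np.dot(inputs, w) + b)
  return inputs

def loss(params, batch):
  inputs, targets = batch
  return np.mean((predict(params, inputs) - targets) ** 2)  # stand-in loss

rng = npr.RandomState(0)
params = [(0.1 * rng.randn(4, 8), 0.1 * rng.randn(8)),   # toy sizes, not 784/1024/10
          (0.1 * rng.randn(8, 3), 0.1 * rng.randn(3))]
batch = (rng.randn(5, 4), rng.randn(5, 3))

# grad(loss) returns a structure mirroring params: one (dw, db) pair per layer.
grads = grad(loss)(params, batch)
step_size = 0.1
params = [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]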
@@ -28,13 +28,12 @@ import time
from absl import app
import matplotlib.pyplot as plt

from jax import lax, random
from jax.api import jit, grad
import jax.numpy as np
from jax import jit, grad, lax, random
from jax.examples import datasets
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import jax.numpy as np


def gaussian_kl(mu, sigmasq):
examples/resnet50.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A mock-up showing a ResNet50 network with training on synthetic data.

This file uses the stax neural network definition library and the minmax
optimization library.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app

import numpy.random as npr

import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, Softmax)


# ResNet blocks compose other layers

def ConvBlock(kernel_size, filters, strides=(2, 2)):
  ks = kernel_size
  filters1, filters2, filters3 = filters
  Main = stax.serial(
      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
      Conv(filters3, (1, 1)), BatchNorm())
  Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
  return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)


def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters
  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())
  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)


# ResNet architectures compose layers and ResNet blocks

def ResNet50(num_classes):
  return stax.serial(
      GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
      BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [64, 64, 256], strides=(1, 1)),
      IdentityBlock(3, [64, 64]),
      IdentityBlock(3, [64, 64]),
      ConvBlock(3, [128, 128, 512]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      ConvBlock(3, [256, 256, 1024]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      ConvBlock(3, [512, 512, 2048]),
      IdentityBlock(3, [512, 512]),
      IdentityBlock(3, [512, 512]),
      AvgPool((7, 7)), Flatten, Dense(num_classes), Softmax)


def main(argv):
  del argv  # Unused.

  batch_size = 8
  num_classes = 1001
  input_shape = (224, 224, 3, batch_size)
  step_size = 0.1
  num_steps = 10

  init_fun, predict_fun = ResNet50(num_classes)
  _, init_params = init_fun(input_shape)

  def loss(params, batch):
    inputs, targets = batch
    logits = predict_fun(params, inputs)
    return np.sum(logits * targets)

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=-1)
    predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
    return np.mean(predicted_class == target_class)

  def synth_batches():
    rng = npr.RandomState(0)
    while True:
      images = rng.rand(*input_shape).astype('float32')
      labels = rng.randint(num_classes, size=(batch_size, 1))
      onehot_labels = labels == np.arange(num_classes)
      yield images, onehot_labels

  opt_init, opt_update = minmax.momentum(step_size, mass=0.9)
  batches = synth_batches()

  @jit
  def update(i, opt_state, batch):
    params = minmax.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  opt_state = opt_init(init_params)
  for i in xrange(num_steps):
    opt_state = update(i, opt_state, next(batches))
  trained_params = minmax.get_params(opt_state)


if __name__ == '__main__':
  app.run(main)
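One easy-to-miss detail in synth_batches above: the one-hot labels come from broadcasting a (batch_size, 1) integer array against np.arange(num_classes). A plain-NumPy illustration with made-up values:

# Sketch, not part of the commit.
import numpy as onp

labels = onp.array([[2], [0]])       # shape (batch_size, 1), as produced by rng.randint above
onehot = labels == onp.arange(4)     # broadcasts against shape (num_classes,) -> (2, 4)
# onehot:
# [[False False  True False]
#  [ True False False False]]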
@@ -428,9 +428,9 @@ pytype_aval_mappings = {}
# ------------------- Products -------------------

class JaxTuple(tuple):
  def __new__(self, xs):
    assert all(map(valid_jaxtype, xs)), xs
    return tuple.__new__(JaxTuple, xs)
  def __new__(cls, xs):
    assert skip_checks or all(map(valid_jaxtype, xs)), xs
    return tuple.__new__(cls, xs)

  def __repr__(self):
    if self is unit:
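The fix above matters because tuple subclasses are constructed in __new__, whose first argument is the class, not an instance; binding it as cls and forwarding it to tuple.__new__ lets subclasses of JaxTuple construct instances of themselves. A minimal generic-Python illustration of the pattern (not JAX code):

# Sketch, not part of the commit.
class Pair(tuple):
  def __new__(cls, xs):
    xs = tuple(xs)
    assert len(xs) == 2, xs
    return tuple.__new__(cls, xs)    # pass cls through so subclasses stay subclasses

class NamedPair(Pair):
  pass

p = NamedPair([1, 2])
assert type(p) is NamedPair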
@@ -23,21 +23,28 @@ from __future__ import division
from __future__ import print_function

import operator
import functools

import jax.numpy as np
from jax.core import pack
from jax.tree_util import tree_map, tree_multimap
import jax.numpy as np


def optimizer(opt_maker):
  """Decorator to make an optimizer map over tuple/list/dict containers."""
  @functools.wraps(opt_maker)
  def tree_opt_maker(*args, **kwargs):
    init_fun, update_fun = opt_maker(*args, **kwargs)

    @functools.wraps(init_fun)
    def fmapped_init_fun(x0_tree):
      return tree_map(lambda x0: pack(init_fun(x0)), x0_tree)

    @functools.wraps(update_fun)
    def fmapped_update_fun(i, grad_tree, state_tree):
      update = lambda g, state: pack(update_fun(i, g, *state))
      return tree_multimap(update, grad_tree, state_tree)

    return fmapped_init_fun, fmapped_update_fun
  return tree_opt_maker
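As a hedged sketch of what the decorator buys: an @optimizer-wrapped pair such as sgd (defined per-leaf in the next hunk) can be handed a whole container of parameters, here a list of (w, b) tuples as in the MNIST example elsewhere in this commit, and the per-leaf init/update logic is mapped over it. The shapes, step size, and stand-in loss below are invented for illustration:

# Sketch, not part of the commit.
import numpy.random as npr
import jax.numpy as np
from jax import grad
from jax.experimental import minmax

def loss(params, batch):
  inputs, targets = batch
  preds = inputs
  for w, b in params:
    preds = np.tanh(np.dot(preds, w) + b)
  return np.mean((preds - targets) ** 2)    # stand-in loss, not from the diff

rng = npr.RandomState(0)
params = [(rng.randn(3, 4), rng.randn(4)),  # a container of (w, b) leaves
          (rng.randn(4, 2), rng.randn(2))]
batch = (rng.randn(5, 3), rng.randn(5, 2))

init_fun, update_fun = minmax.sgd(0.05)     # accepts the whole container thanks to @optimizer
state = init_fun(params)
state = update_fun(0, grad(loss)(params, batch), state)
new_params = minmax.get_params(state)       # a.k.a. iterate(state)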
@@ -46,42 +53,86 @@ def iterate(state_tree):
  return tree_map(lambda state: tuple(state)[0], state_tree)
get_params = iterate

# optimizers

@optimizer
def sgd(step_size):
  """Init and update step functions for stochastic gradient descent."""
  """Construct init and update step functions for stochastic gradient descent.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
    return (x0,)
  def update_fun(i, g, x):
    return (x - step_size * g,)
    return (x - step_size(i) * g,)
  return init_fun, update_fun

@optimizer
def momentum(step_size, mass):
  """Init and update step functions for SGD with Nesterov momentum."""
  """Construct init and update step functions for SGD with Nesterov momentum.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
    v0 = np.zeros_like(x0)
    return x0, v0
  def update_fun(i, g, x, velocity):
    velocity = mass * velocity - (1. - mass) * g
    x = x + step_size * velocity
    x = x + step_size(i) * velocity
    return x, velocity
  return init_fun, update_fun

@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
  """Init and update step functions for RMSProp."""
  """Construct init and update step functions for RMSProp.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.

  Returns:
    An (init_fun, update_fun) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
    avg_sq_grad = np.ones_like(x0)
    return x0, avg_sq_grad
  def update_fun(i, g, x, avg_sq_grad):
    avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
    x = x - step_size * g / (np.sqrt(avg_sq_grad) + eps)
    x = x - step_size(i) * g / (np.sqrt(avg_sq_grad) + eps)
    return x, avg_sq_grad
  return init_fun, update_fun

@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Init and update step functions for Adam."""
  """Construct init and update step functions for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun) pair.
  """
  step_size = make_schedule(step_size)
  def init_fun(x0):
    m0 = np.zeros_like(x0)
    v0 = np.zeros_like(x0)
@@ -91,26 +142,47 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
    v = (1 - b2) * (g ** 2) + b2 * v  # Second moment estimate.
    mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
    vhat = v / (1 - b2 ** (i + 1))
    x = x - step_size*mhat / (np.sqrt(vhat) + eps)
    x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
    return x, m, v
  return init_fun, update_fun

def run_optimizer(loss, infeed, update_fun, state):
  """A convenience function for running optimizers with iterated map-reduce.
# learning rate schedules

  Args:
    loss: a scalar-valued loss function taking two aguments, the current iterate
      and a data value.
    infeed: an infeed instance supplying the data stream.
    update_fun: a function that has signature update_fun(i, grad, state) where
      i is the integer iteration count, grad is the gradient of the loss at the
      current iterate, and state is the current optimizer state.
    state: the initial optimizer state.
def constant(step_size):
  def schedule(i):
    return step_size
  return schedule

  Returns:
    A pair (x, state) where is the final iterate and state represents the final
    optimizer state.
  """
  map_fun = lambda _, state, batch: grad(loss)(iterate(state), batch)
  state = fax.iterated_map_reduce(state, map_fun, update_fun, infeed)
  return iterate(state), state
def exponential_decay(step_size, decay_steps, decay_rate):
  def schedule(i):
    return step_size * decay_rate ** (i / decay_steps)
  return schedule

def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
  if staircase:
    def schedule(i):
      return step_size / (1 + decay_rate * np.floor(i / decay_steps))
  else:
    def schedule(i):
      return step_size / (1 + decay_rate * i / decay_steps)
  return schedule

def piecewise_constant(boundaries, values):
  boundaries = np.array(boundaries)
  values = np.array(values)
  if not boundaries.ndim == values.ndim == 1:
    raise ValueError("boundaries and values must be sequences")
  if not boundaries.shape[0] == values.shape[0] - 1:
    raise ValueError("boundaries length must be one longer than values length")

  def schedule(i):
    return values[np.sum(i > boundaries)]
  return schedule

def make_schedule(constant_scalar_or_schedule_fun):
  if np.isscalar(constant_scalar_or_schedule_fun):
    return constant(constant_scalar_or_schedule_fun)
  elif callable(constant_scalar_or_schedule_fun):
    return constant_scalar_or_schedule_fun
  else:
    raise TypeError(type(constant_scalar_or_schedule_fun))
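A brief sketch of how the new schedules plug in: make_schedule wraps a bare scalar in a constant schedule and passes callables through, so any schedule factory above can be handed to an optimizer wherever a scalar step size was accepted before. The concrete numbers below are arbitrary:

# Sketch, not part of the commit.
from jax.experimental import minmax

# A bare scalar still works; make_schedule wraps it in a constant schedule.
opt_init, opt_update = minmax.sgd(0.1)

# Or pass a schedule: step size 0.1, halved every 1000 steps.
schedule = minmax.exponential_decay(0.1, decay_steps=1000, decay_rate=0.5)
opt_init, opt_update = minmax.momentum(schedule, mass=0.9)

# Schedules are plain functions of the step index.
assert abs(schedule(0) - 0.1) < 1e-12
assert abs(schedule(1000) - 0.05) < 1e-12

# piecewise_constant([25, 75], [1.0, 0.5, 0.1]) gives 1.0 for i <= 25,
# 0.5 for 25 < i <= 75, and 0.1 afterwards, matching the new minmax tests.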
@@ -173,12 +173,12 @@ class JaxprTracerTuple(tuple): pass
Destructuring = namedtuple('Destructuring', ['i', 'eqn', 'key'])

class PartialVal(tuple):
  def __init__(self, xs):
  def __new__(cls, xs):
    assert core.skip_checks or (
        isinstance(xs[0], valid_pv_types)
        and isinstance(xs[1], core.Tracer) or core.valid_jaxtype(xs[1])
    ), xs
    super(PartialVal, self).__init__(xs)
    return tuple.__new__(cls, xs)

valid_pv_types = (AbstractValue, JaxprTracerTuple, type(None))

jax/lax.py (21 lines changed)
@@ -1566,9 +1566,30 @@ def slice_transpose_rule(t, start_indices, limit_indices, strides,
  assert result.shape == operand_shape
  return [result]

def slice_batching_rule(batched_args, batch_dims, start_indices, limit_indices,
                        strides, **unused_kwargs):
  operand, = batched_args
  bdim, = batch_dims

  new_start_indices = list(start_indices)
  new_start_indices.insert(bdim, 0)

  new_limit_indices = list(limit_indices)
  new_limit_indices.insert(bdim, operand.shape[bdim])

  if strides is None:
    new_strides = None
  else:
    new_strides = list(strides)
    new_strides.insert(bdim, 1)

  out = slice(operand, new_start_indices, new_limit_indices, new_strides)
  return out, bdim

slice_p = standard_primitive(slice_shape_rule, _input_dtype, 'slice',
                             slice_translation_rule)
ad.deflinear(slice_p, slice_transpose_rule)
batching.primitive_batchers[slice_p] = slice_batching_rule


def dynamic_slice_shape_rule(operand, start_indices, slice_sizes,
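The new batching rule is what lets vmap map over lax.slice; the test added to the batching tests later in this commit exercises exactly this. A condensed sketch of that usage, keeping the vmap(fun, x) call form used by the test at this point in the code's history:

# Sketch, not part of the commit.
import numpy as onp
from jax import lax
from jax.api import vmap

fun = lambda x: lax.slice(x, (2,), (4,))        # slices a length-10 vector down to 2 elements
x = onp.random.RandomState(0).randn(5, 10)      # a batch of 5 such vectors

ans = vmap(fun, x)                              # the batching rule threads the batch dimension
assert ans.shape == (5, 2)                      # equivalent to x[:, 2:4]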
@@ -171,7 +171,7 @@ def _constant_like(x, const):
def _wraps(fun):
  """Like functools.wraps but works with numpy.ufuncs."""
  docstr = """
  LAX-backed implementation of {fun}. Corresponding Numpy docstring below.
  LAX-backed implementation of {fun}. Original docstring below.

  {np_doc}
  """.format(fun=fun.__name__, np_doc=fun.__doc__)
@@ -309,7 +309,7 @@ def shuffle(key, x, axis=0):
  for _ in range(num_rounds):
    key, subkey = split(key)
    sort_keys = _random_bits(subkey, 32, x.shape)
    _, x = lax.sort_keyval(sort_keys, x, axis)
    _, x = lax.sort_key_val(sort_keys, x, axis)

  return x

@@ -20,22 +20,7 @@ import numpy as onp
import scipy.misc as osp_misc

from .. import lax


def _wraps(fun):
  """Like functools.wraps but works with numpy.ufuncs."""
  docstr = """
  LAX-backed implementation of {fun}. Corresponding Scipy docstring below.

  {np_doc}
  """.format(fun=fun.__name__, np_doc=fun.__doc__)
  def wrap(op):
    try:
      op.__name__ = fun.__name__
      op.__doc__ = docstr
    finally:
      return op
  return wrap
from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like


@_wraps(osp_misc.logsumexp)
@@ -50,19 +35,3 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                                   _constant_like(a, 0), lax.add, dims)), amax)
  return dimadd(out) if keepdims else out


# TODO(mattjj): this is duplicated from lax_numpy.py
def _reduction_dims(a, axis):
  if axis is None:
    return onp.arange(onp.ndim(a))
  elif isinstance(axis, (onp.ndarray, tuple, list)):
    return onp.mod(onp.asarray(axis), onp.ndim(a))
  elif isinstance(axis, int):
    return onp.mod([axis], onp.ndim(a))
  else:
    raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))


def _constant_like(x, const):
  return onp.array(const, dtype=lax._dtype(x))
@@ -19,22 +19,7 @@ from __future__ import print_function
import scipy.special as osp_special

from .. import lax


def _wraps(fun):
  """Like functools.wraps but works with numpy.ufuncs."""
  docstr = """
  LAX-backed implementation of {fun}. Corresponding Scipy docstring below.

  {np_doc}
  """.format(fun=fun.__name__, np_doc=fun.__doc__)
  def wrap(op):
    try:
      op.__name__ = fun.__name__
      op.__doc__ = docstr
    finally:
      return op
  return wrap
from ..numpy.lax_numpy import _wraps


gammaln = _wraps(osp_special.gammaln)(lax.lgamma)
jax/scipy/stats/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from . import norm
jax/scipy/stats/norm.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as onp
import scipy.stats as osp_stats

from ... import lax
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps


@_wraps(osp_stats.norm.logpdf)
def logpdf(x, loc=0, scale=1):
  x, loc, scale = _promote_args_like(osp_stats.norm.logpdf, x, loc, scale)
  two = _constant_like(x, 2)
  scale_sqrd = lax.pow(scale, two)
  log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * onp.pi), scale_sqrd))
  quadratic = lax.div(lax.pow(lax.sub(x, loc), two), scale_sqrd)
  return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
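For reference, the lax expression above is the standard normal log-density, logpdf(x; loc, scale) = -((x - loc)^2 / scale^2 + log(2*pi*scale^2)) / 2. A plain NumPy/SciPy check of the algebra with arbitrary values:

# Sketch, not part of the commit.
import numpy as onp
import scipy.stats as osp_stats

x, loc, scale = 1.3, 0.2, 1.7
manual = -(((x - loc) ** 2) / scale ** 2 + onp.log(2 * onp.pi * scale ** 2)) / 2
assert onp.allclose(manual, osp_stats.norm.logpdf(x, loc, scale))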
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax.abstract_arrays import ShapedArray
from jax import lax
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
from jax.api import vmap
from jax.core import unit
@@ -135,6 +136,24 @@ class BatchingTest(jtu.JaxTestCase):
                        check_dtypes=False)
    self.assertEqual(len(side), 1)

  def testSliceLax(self):
    fun = lambda x: lax.slice(x, (2,), (4,))
    R = onp.random.RandomState(0).randn
    x = R(5, 10)

    ans = vmap(fun, x)
    expected_ans = x[:, 2:4]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testSliceNumpy(self):
    fun = lambda x: x[:, 2]
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun, x)
    expected_ans = x[:, :, 2]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)


if __name__ == '__main__':
  absltest.main()

@@ -27,15 +27,17 @@ from absl.testing import parameterized
import numpy as onp
import scipy.misc as osp_misc
import scipy.special as osp_special
import scipy.stats as osp_stats

from jax import api
from jax import test_util as jtu
from jax.scipy import misc as lsp_misc
from jax.scipy import special as lsp_special
from jax.scipy import stats as lsp_stats

FLAGS = flags.FLAGS

all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]

float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
@@ -79,6 +81,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

@@ -101,6 +104,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, modes):
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    # TODO(mattjj): test autodiff
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
        and not shapes[0]):
      return absltest.unittest.skip("real() on scalar not supported on tpu")
@@ -111,7 +115,40 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
                            check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

  # TODO(mattjj): add test for lsp_stats.norm_logpdf
  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 3)
      for dtypes in CombosWithReplacement(default_dtypes, 3)
      for rng in [jtu.rand_default()])
  def testNormLogPdfThreeArgs(self, rng, shapes, dtypes):
    # TODO(mattjj): test autodiff
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      scale = 0.5 + onp.abs(scale)
      return [x, loc, scale]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 2)
      for dtypes in CombosWithReplacement(default_dtypes, 2)
      for rng in [jtu.rand_default()])
  def testNormLogPdfTwoArgs(self, rng, shapes, dtypes):
    # TODO(mattjj): test autodiff
    scale = 0.5
    scipy_fun = functools.partial(osp_stats.norm.logpdf, scale=scale)
    lax_fun = functools.partial(lsp_stats.norm.logpdf, scale=scale)
    def args_maker():
      return list(map(rng, shapes, dtypes))
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


if __name__ == "__main__":
@@ -114,5 +114,41 @@ class OptimizerTests(jtu.JaxTestCase):
    partial_loss = functools.partial(loss, y)
    self._CheckRun(minmax.sgd, partial_loss, x0, num_iters, step_size)

  def testSgdVectorExponentialDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = minmax.exponential_decay(0.1, 3, 2.)
    self._CheckOptimizer(minmax.sgd, loss, x0, num_iters, step_schedule)

  def testSgdVectorInverseTimeDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = minmax.inverse_time_decay(0.1, 3, 2.)
    self._CheckOptimizer(minmax.sgd, loss, x0, num_iters, step_schedule)

  def testAdamVectorInverseTimeDecaySchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = minmax.inverse_time_decay(0.1, 3, 2.)
    self._CheckOptimizer(minmax.adam, loss, x0, num_iters, step_schedule)

  def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_sched = minmax.inverse_time_decay(0.1, 3, 2., staircase=True)
    mass = 0.9
    self._CheckOptimizer(minmax.momentum, loss, x0, num_iters, step_sched, mass)

  def testRmspropVectorPiecewiseConstantSchedule(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = minmax.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
    self._CheckOptimizer(minmax.rmsprop, loss, x0, num_iters, step_schedule)

if __name__ == '__main__':
  absltest.main()

@@ -127,6 +127,23 @@ class LaxRandomTest(jtu.JaxTestCase):
    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
      for dtype in [onp.float32, onp.float64, onp.int32, onp.int64])
  def testShuffle(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertTrue(onp.all(perm1 == perm2))
    self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertTrue(onp.all(onp.sort(perm1) == x))


if __name__ == "__main__":
  absltest.main()
