mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
tweak docstrings in mnist examples
This commit is contained in:
parent
5d902298b9
commit
92e5f93a29
38
examples/fluidsim.py
Normal file
38
examples/fluidsim.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
def project(vx, vy):
|
||||
h = 1. / vx.shape[0]
|
||||
div = -0.5 * h * (np.roll(vx, -1, axis=0) - np.roll(vx, 1, axis=0) +
|
||||
np.roll(vy, -1, axis=1) - np.roll(vy, 1, axis=1))
|
||||
|
||||
p = np.zeros(vx.shape)
|
||||
for k in range(10):
|
||||
p = (div + np.roll(p, 1, axis=0) + np.roll(p, -1, axis=0)
|
||||
+ np.roll(p, 1, axis=1) + np.roll(p, -1, axis=1)) / 4.
|
||||
|
||||
vx = vx - 0.5 * (np.roll(p, -1, axis=0) - np.roll(p, 1, axis=0)) / h
|
||||
vy = vy - 0.5 * (np.roll(p, -1, axis=1) - np.roll(p, 1, axis=1)) / h
|
||||
return vx, vy
|
||||
|
||||
def advect(f, vx, vy):
|
||||
rows, cols = f.shape
|
||||
cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
|
||||
return linear_interpolate(f, cell_xs - vx, cell_ys - vy)
|
||||
|
||||
def linear_interpolate(
|
98
examples/mnist_bench.py
Normal file
98
examples/mnist_bench.py
Normal file
@ -0,0 +1,98 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A basic MNIST example using JAX together with the mini-libraries stax, for
|
||||
neural network building, and optimizers, for first-order stochastic optimization.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import itertools
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
inputs, targets = batch
|
||||
preds = predict(params, inputs)
|
||||
return -np.mean(preds * targets)
|
||||
|
||||
def accuracy(params, batch):
|
||||
inputs, targets = batch
|
||||
target_class = np.argmax(targets, axis=1)
|
||||
predicted_class = np.argmax(predict(params, inputs), axis=1)
|
||||
return np.mean(predicted_class == target_class)
|
||||
|
||||
init_random_params, predict = stax.serial(
|
||||
Dense(1024), Relu,
|
||||
Dense(1024), Relu,
|
||||
Dense(10), LogSoftmax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
rng = random.PRNGKey(0)
|
||||
|
||||
step_size = 0.001
|
||||
num_epochs = 3
|
||||
batch_size = 128
|
||||
momentum_mass = 0.9
|
||||
|
||||
train_images, train_labels, test_images, test_labels = datasets.mnist()
|
||||
num_train = train_images.shape[0]
|
||||
num_complete_batches, leftover = divmod(num_train, batch_size)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
|
||||
def data_stream():
|
||||
rng = npr.RandomState(0)
|
||||
while True:
|
||||
perm = rng.permutation(num_train)
|
||||
for i in range(num_batches):
|
||||
batch_idx = perm[i * batch_size:(i + 1) * batch_size]
|
||||
yield train_images[batch_idx], train_labels[batch_idx]
|
||||
batches = data_stream()
|
||||
|
||||
opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)
|
||||
|
||||
@jit
|
||||
def update(i, opt_state, batch):
|
||||
params = optimizers.get_params(opt_state)
|
||||
return opt_update(i, grad(loss)(params, batch), opt_state)
|
||||
|
||||
_, init_params = init_random_params(rng, (-1, 28 * 28))
|
||||
opt_state = opt_init(init_params)
|
||||
itercount = itertools.count()
|
||||
|
||||
print("\nStarting training...")
|
||||
for epoch in range(num_epochs):
|
||||
start_time = time.time()
|
||||
for _ in range(num_batches):
|
||||
opt_state = update(next(itercount), opt_state, next(batches))
|
||||
epoch_time = time.time() - start_time
|
||||
|
||||
params = optimizers.get_params(opt_state)
|
||||
train_acc = accuracy(params, (train_images, train_labels))
|
||||
test_acc = accuracy(params, (test_images, test_labels))
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
@ -12,8 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A basic MNIST example using JAX together with the mini-libraries stax, for
|
||||
neural network building, and optimizers, for first-order stochastic optimization.
|
||||
"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
|
||||
|
||||
The mini-library jax.experimental.stax is for neural network building, and
|
||||
the mini-library jax.experimentaloptimizers is for first-order stochastic
|
||||
optimization.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@ -12,9 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A basic MNIST example using Numpy and JAX.
|
||||
"""An MNIST example with single-program multiple-data (SPMD) data parallelism.
|
||||
|
||||
The primary aim here is simplicity and minimal dependencies.
|
||||
The aim here is to illustrate how to use JAX's `pmap` to express and execute
|
||||
SPMD programs for data parallelism along a batch dimension, while also
|
||||
minimizing dependencies by avoiding the use of higher-level layers and
|
||||
optimizers libraries.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
56
examples/spmd_spatially_sharded_conv_net.py
Normal file
56
examples/spmd_spatially_sharded_conv_net.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""An MNIST example for single-program multiple-data (SPMD) spatial parallelism.
|
||||
|
||||
The aim here is to illustrate how to use JAX's `pmap` to express and execute
|
||||
SPMD programs for data parallelism along a spatial dimension (rather than a
|
||||
batch dimension).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import itertools
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Relu, LogSoftmax
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
inputs, targets = batch
|
||||
preds = predict(params, inputs)
|
||||
return -np.mean(preds * targets)
|
||||
|
||||
def accuracy(params, batch):
|
||||
inputs, targets = batch
|
||||
target_class = np.argmax(targets, axis=1)
|
||||
predicted_class = np.argmax(predict(params, inputs), axis=1)
|
||||
return np.mean(predicted_class == target_class)
|
||||
|
||||
|
||||
init_random_params, predict = stax.serial(
|
||||
SpmdConv(3, (2, 2), axis_name="x"), Relu,
|
||||
SpmdConv(3, (2, 2), axis_name="x"), Relu,
|
||||
SpmdDense(10), LogSoftmax)
|
Loading…
x
Reference in New Issue
Block a user