mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details. PiperOrigin-RevId: 476167538
96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""A basic MNIST example using Numpy and JAX.
|
|
|
|
The primary aim here is simplicity and minimal dependencies.
|
|
"""
|
|
|
|
|
|
import time
|
|
|
|
import numpy.random as npr
|
|
|
|
from jax import jit, grad
|
|
from jax.scipy.special import logsumexp
|
|
import jax.numpy as jnp
|
|
from examples import datasets
|
|
|
|
|
|
def init_random_params(scale, layer_sizes, rng=None):
  """Create randomly initialized (weight, bias) pairs for a dense network.

  Args:
    scale: multiplier applied to the standard-normal samples.
    layer_sizes: sequence of layer widths, input width first, output last.
    rng: a ``numpy.random.RandomState`` to draw from. When omitted, a fresh
      ``RandomState(0)`` is created on every call, so the default output is
      reproducible and independent of earlier calls.

  Returns:
    A list with one ``(w, b)`` pair per consecutive ``(m, n)`` in
    ``layer_sizes``; ``w`` has shape ``(m, n)`` and ``b`` has shape ``(n,)``.
  """
  # A RandomState used directly as a default argument is built once at
  # definition time and shared by every call, so each default call would
  # continue the previous call's random stream instead of restarting it.
  if rng is None:
    rng = npr.RandomState(0)
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
|
|
|
|
def predict(params, inputs):
  """Forward pass of the MLP, returning row-normalized log-probabilities.

  Every layer but the last is an affine map followed by ``tanh``. The last
  layer is affine only, and its output has its ``logsumexp`` over axis 1
  subtracted.

  Args:
    params: sequence of ``(w, b)`` pairs, one per layer.
    inputs: input array; the final normalization reduces over axis 1, so this
      is presumably a 2-D (batch, features) array — confirm against callers.

  Returns:
    The final-layer scores minus their per-row ``logsumexp``.
  """
  hidden = inputs
  for weight, bias in params[:-1]:
    hidden = jnp.tanh(jnp.dot(hidden, weight) + bias)

  out_w, out_b = params[-1]
  scores = jnp.dot(hidden, out_w) + out_b
  # Subtracting the row-wise logsumexp is a numerically stable log-softmax.
  return scores - logsumexp(scores, axis=1, keepdims=True)
|
|
|
|
def loss(params, batch):
  """Negated mean of the target-weighted log-probabilities over a batch.

  When ``targets`` is one-hot (presumably the case here; confirm against the
  data loader), this is the average cross-entropy.

  Args:
    params: network parameters in the form accepted by ``predict``.
    batch: an ``(inputs, targets)`` pair; ``targets`` is multiplied
      elementwise with the predictions and summed over axis 1.

  Returns:
    A scalar: ``-mean_over_rows(sum_over_axis_1(predictions * targets))``.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = jnp.sum(log_probs * targets, axis=1)
  return -jnp.mean(per_example)
|
|
|
|
def accuracy(params, batch):
  """Fraction of rows whose predicted argmax equals the target argmax.

  Args:
    params: network parameters in the form accepted by ``predict``.
    batch: an ``(inputs, targets)`` pair; class labels are recovered from
      ``targets`` with ``argmax`` over axis 1.

  Returns:
    The mean of the per-row match indicator.
  """
  inputs, targets = batch
  true_labels = jnp.argmax(targets, axis=1)
  guessed_labels = jnp.argmax(predict(params, inputs), axis=1)
  matches = guessed_labels == true_labels
  return jnp.mean(matches)
|
|
|
|
|
|
if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  # Count a trailing partial batch when num_train isn't a multiple of batch_size.
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    """Yield (images, labels) minibatches forever, reshuffling on each pass."""
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  @jit
  def update(params, batch):
    """Take one plain gradient-descent step of size ``step_size`` on ``loss``."""
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = update(params, next(batches))
    # JAX dispatches work asynchronously: `update` returns before the device
    # has finished. Wait for the parameters so epoch_time measures the actual
    # training work rather than only the time spent enqueuing it.
    for w, b in params:
      w.block_until_ready()
      b.block_until_ready()
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")
|