mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details. PiperOrigin-RevId: 476167538
94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Datasets used in examples."""
|
|
|
|
|
|
import array
|
|
import gzip
|
|
import os
|
|
from os import path
|
|
import struct
|
|
import urllib.request
|
|
|
|
import numpy as np
|
|
|
|
|
|
# Directory that holds downloaded example datasets; files already present
# here are reused instead of being fetched again.
_DATA = "/tmp/jax_example_data/"
|
|
|
|
|
|
def _download(url, filename):
  """Download a url to a file in the JAX data temp directory.

  Nothing is fetched when the target file already exists. Otherwise the data
  is written to a temporary sibling file and renamed into place only once the
  transfer has finished, so an interrupted download is never mistaken for a
  complete file by a later call.

  Args:
    url: Address to fetch.
    filename: Name of the file to create inside ``_DATA``.

  Raises:
    Whatever ``os.makedirs``, ``urllib.request.urlretrieve`` or ``os.replace``
    raise; any partially written temporary file is removed first.
  """
  # exist_ok avoids a crash when another process creates the directory
  # between an existence check and the mkdir call.
  os.makedirs(_DATA, exist_ok=True)
  out_file = path.join(_DATA, filename)
  if path.isfile(out_file):
    return

  # The process id keeps concurrent downloaders from writing to the same
  # temporary file.
  partial_file = f"{out_file}.{os.getpid()}.part"
  try:
    urllib.request.urlretrieve(url, partial_file)
    os.replace(partial_file, out_file)
  finally:
    # After a successful rename the temporary file is gone; it only survives
    # (and needs cleaning up) when the download or rename failed.
    if path.exists(partial_file):
      os.remove(partial_file)
  print(f"downloaded {url} to {_DATA}")
|
|
|
|
|
|
def _partial_flatten(x):
  """Flatten all but the first dimension of an ndarray.

  Args:
    x: Array with at least one dimension.

  Returns:
    A two-dimensional array of shape ``(x.shape[0], prod(x.shape[1:]))``. A
    one-dimensional input of length ``n`` becomes shape ``(n, 1)``.
  """
  # Compute the trailing size explicitly instead of passing -1: NumPy cannot
  # infer a -1 dimension when the array has zero elements (e.g. an empty
  # batch), and would raise ValueError.
  trailing = int(np.prod(x.shape[1:], dtype=np.int64))
  return np.reshape(x, (x.shape[0], trailing))
|
|
|
|
|
|
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k.

  Args:
    x: Array of class indices. It is indexed as ``x[:, None]``, so a
      one-dimensional input of length ``n`` gives an ``(n, k)`` result.
    k: Number of classes; entries of ``x`` are compared against ``0..k-1``.
    dtype: Element type of the returned array.

  Returns:
    An array holding 1 where an entry of ``x`` equals the column index and 0
    elsewhere. Entries outside ``0..k-1`` produce an all-zero row.
  """
  class_ids = np.arange(k)
  # Broadcasting a column of labels against a row of class ids gives a
  # boolean membership table, which is then converted to the requested dtype.
  matches = x[:, None] == class_ids
  return np.array(matches, dtype)
|
|
|
|
|
|
def mnist_raw():
  """Download and parse the raw MNIST dataset.

  The four gzip archives are fetched into ``_DATA`` via ``_download`` and then
  decoded.

  Returns:
    A tuple ``(train_images, train_labels, test_images, test_labels)`` of
    ``uint8`` arrays. Image arrays are shaped ``(count, rows, cols)`` using the
    sizes read from each file's header; label arrays are one-dimensional.
  """
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  train_images_gz = "train-images-idx3-ubyte.gz"
  train_labels_gz = "train-labels-idx1-ubyte.gz"
  test_images_gz = "t10k-images-idx3-ubyte.gz"
  test_labels_gz = "t10k-labels-idx1-ubyte.gz"

  def read_labels(gz_path):
    with gzip.open(gz_path, "rb") as stream:
      # The header is two big-endian 32-bit fields; both are read and
      # discarded, and every remaining byte is one label.
      struct.unpack(">II", stream.read(8))
      payload = array.array("B", stream.read())
    return np.array(payload, dtype=np.uint8)

  def read_images(gz_path):
    with gzip.open(gz_path, "rb") as stream:
      # The header is four big-endian 32-bit fields; the last three give the
      # image count and the per-image row/column sizes.
      header = struct.unpack(">IIII", stream.read(16))
      payload = array.array("B", stream.read())
    _, count, height, width = header
    pixels = np.array(payload, dtype=np.uint8)
    return pixels.reshape(count, height, width)

  for archive in (train_images_gz, train_labels_gz,
                  test_images_gz, test_labels_gz):
    _download(base_url + archive, archive)

  return (
      read_images(path.join(_DATA, train_images_gz)),
      read_labels(path.join(_DATA, train_labels_gz)),
      read_images(path.join(_DATA, test_images_gz)),
      read_labels(path.join(_DATA, test_labels_gz)),
  )
|
|
|
|
|
|
def mnist(permute_train=False):
  """Download, parse and process MNIST data to unit scale and one-hot labels.

  Args:
    permute_train: When true, reorder the training images and labels together
      using a permutation drawn from ``np.random.RandomState(0)``, so the
      order is the same on every call.

  Returns:
    A tuple ``(train_images, train_labels, test_images, test_labels)``. Images
    are flattened to two dimensions and divided by 255; labels are one-hot
    encoded over 10 classes.
  """
  raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist_raw()

  # Dividing the raw pixel bytes by 255 maps them onto the unit interval.
  scale = np.float32(255.)
  train_images = _partial_flatten(raw_train_x) / scale
  test_images = _partial_flatten(raw_test_x) / scale
  train_labels = _one_hot(raw_train_y, 10)
  test_labels = _one_hot(raw_test_y, 10)

  if not permute_train:
    return train_images, train_labels, test_images, test_labels

  # One shared index order keeps each image paired with its label.
  order = np.random.RandomState(0).permutation(train_images.shape[0])
  return train_images[order], train_labels[order], test_images, test_labels
|