mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
update notebooks b/c jax comes with colab now!
This commit is contained in:
parent
be4dc0eb78
commit
b672ad72c8
@ -122,7 +122,7 @@ PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37
|
||||
CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100
|
||||
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
|
||||
BASE_URL='https://storage.googleapis.com/jax-releases'
|
||||
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.23-$PYTHON_VERSION-none-$PLATFORM.whl
|
||||
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.27-$PYTHON_VERSION-none-$PLATFORM.whl
|
||||
|
||||
pip install --upgrade jax # install jax
|
||||
```
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -117,20 +117,6 @@
|
||||
"### Setup and imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "2NXj3Dp5270W",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
@ -603,7 +589,7 @@
|
||||
"b, a = center(jnp.arange(3))\n",
|
||||
"print(np.array(a), np.array(b))"
|
||||
],
|
||||
"execution_count": 14,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -629,7 +615,7 @@
|
||||
"X = jnp.arange(12).reshape((3, 4))\n",
|
||||
"X"
|
||||
],
|
||||
"execution_count": 18,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -662,7 +648,7 @@
|
||||
"b, a = center(X, axis=1)\n",
|
||||
"print(np.array(a), np.array(b))"
|
||||
],
|
||||
"execution_count": 19,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -690,7 +676,7 @@
|
||||
"b, a = center(X, axis=0)\n",
|
||||
"print(np.array(a), np.array(b))"
|
||||
],
|
||||
"execution_count": 20,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -719,7 +705,7 @@
|
||||
"b, a = center.__wrapped__(X)\n",
|
||||
"print(np.array(a), np.array(b))"
|
||||
],
|
||||
"execution_count": 21,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -733,4 +719,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -55,20 +55,6 @@
|
||||
"- extending MAML to handle batching at the task-level\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "PaW85yP_BrCF",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
@ -118,7 +104,7 @@
|
||||
"print(grad(grad(g))(2.)) # x = 2\n",
|
||||
"print(grad(grad(grad(g)))(2.)) # x = 0"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -225,7 +211,7 @@
|
||||
"plt.plot(xrange_inputs, targets, label='target')\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 7,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -314,7 +300,7 @@
|
||||
"plt.plot(xrange_inputs, targets, label='target')\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 10,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -384,7 +370,7 @@
|
||||
"print('maml_objective(x,y)={}'.format(maml_objective(x0, y0))) # x**2 + 1 = 5\n",
|
||||
"print('x0 - maml_objective(x,y) = {}'.format(x0 - grad(maml_objective)(x0, y0))) # x - (2x)"
|
||||
],
|
||||
"execution_count": 11,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -450,7 +436,7 @@
|
||||
"y2 = np.array([0.])\n",
|
||||
"maml_loss(net_params, x1, y1, x2, y2)"
|
||||
],
|
||||
"execution_count": 13,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -519,7 +505,7 @@
|
||||
" print(i)\n",
|
||||
"net_params = get_params(opt_state)"
|
||||
],
|
||||
"execution_count": 14,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -576,7 +562,7 @@
|
||||
" plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 15,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -670,7 +656,7 @@
|
||||
" plt.scatter(x2[i], y2[i], label='task{}-val'.format(i))\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 17,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -712,7 +698,7 @@
|
||||
"source": [
|
||||
"x2.shape"
|
||||
],
|
||||
"execution_count": 18,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -767,7 +753,7 @@
|
||||
" print(i)\n",
|
||||
"net_params = get_params(opt_state)"
|
||||
],
|
||||
"execution_count": 20,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -824,7 +810,7 @@
|
||||
" plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 21,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -870,7 +856,7 @@
|
||||
"plt.ylim(0., 1e-1)\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
"execution_count": 22,
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
@ -912,4 +898,4 @@
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -17,11 +17,11 @@
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "18AF5Ab4p6VL",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"##### Copyright 2018 Google LLC.\n",
|
||||
"\n",
|
||||
@ -29,11 +29,11 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "crfqaJOyp8bq",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"you may not use this file except in compliance with the License.\n",
|
||||
@ -49,16 +49,14 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "B_XlLLpcWjkA",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Training a Simple Neural Network, with PyTorch Data Loading\n",
|
||||
"\n",
|
||||
"_Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary_\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
|
||||
@ -67,26 +65,12 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "-8OFzj9TqXof",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "OksHydJDtbbI",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from __future__ import print_function, division, absolute_import\n",
|
||||
"import jax.numpy as np\n",
|
||||
@ -97,23 +81,23 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "MTVcKi-ZYB3R",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Hyperparameters\n",
|
||||
"Let's get a few bookkeeping items out of the way."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "-fmWA06xYE7d",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# A helper function to randomly initialize weights and biases\n",
|
||||
"# for a dense neural network layer\n",
|
||||
@ -138,11 +122,11 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BtoNk_yxWtIw",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Auto-batching predictions\n",
|
||||
"\n",
|
||||
@ -150,12 +134,12 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "7APc6tD7TiuZ",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from jax.scipy.special import logsumexp\n",
|
||||
"\n",
|
||||
@ -177,22 +161,22 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dRW_TvCTWgaP",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Let's check that our prediction function only works on single images."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "4sW2A5mnXHc5",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# This works on single examples\n",
|
||||
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
|
||||
@ -203,12 +187,12 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PpyQxuedXfhp",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Doesn't work with a batch\n",
|
||||
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
|
||||
@ -221,12 +205,12 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "oJOOncKMXbwK",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's upgrade it to handle batches using `vmap`\n",
|
||||
"\n",
|
||||
@ -241,32 +225,32 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "elsG6nX03BvW",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "NwDuFqc9X7ER",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Utility and loss functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "6lTI6I4lWdh5",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def one_hot(x, k, dtype=np.float32):\n",
|
||||
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
|
||||
@ -291,11 +275,11 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "umJJGZCC2oKl",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Data Loading with PyTorch\n",
|
||||
"\n",
|
||||
@ -303,12 +287,12 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "gEvWt8_u2pqG",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install torch torchvision"
|
||||
],
|
||||
@ -316,13 +300,13 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "94PjXZ8y3dVF",
|
||||
"colab_type": "code",
|
||||
"cellView": "both",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"from torch.utils import data\n",
|
||||
@ -363,12 +347,12 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "l314jsfP4TN4",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Define our dataset, using torch datasets\n",
|
||||
"mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())\n",
|
||||
@ -378,12 +362,12 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "FTNo4beUvb6t",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Get the full train dataset (for checking accuracy while training)\n",
|
||||
"train_images = onp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)\n",
|
||||
@ -398,22 +382,22 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xxPd6Qw3Z98v",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Training Loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "X2DnZo3iYj18",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
@ -434,15 +418,15 @@
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xC1CMcVNYwxm",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We've now used the whole of the JAX API: `grad` for derivatives, `jit` for speedups and `vmap` for auto-vectorization.\n",
|
||||
"We used NumPy to specify all of our computation, and borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -57,7 +57,6 @@
|
||||
},
|
||||
"source": [
|
||||
"# JAX Quickstart\n",
|
||||
"Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@ -81,20 +80,6 @@
|
||||
"to leave Python.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "PaW85yP_BrCF",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
@ -178,7 +163,9 @@
|
||||
"id": "0AlN7EbonyaR"
|
||||
},
|
||||
"source": [
|
||||
"JAX NumPy functions work on regular NumPy arrays."
|
||||
"We added that `block_until_ready` because [JAX uses asynchronous execution by default](https://jax.readthedocs.io/en/latest/async_dispatch.html).\n",
|
||||
"\n",
|
||||
"JAX NumPy functions work on regular NumPy arrays. "
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -230,7 +217,7 @@
|
||||
"id": "clO9djnen8qi"
|
||||
},
|
||||
"source": [
|
||||
"The output of `device_put` still acts like an NDArray."
|
||||
"The output of `device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of `device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -592,4 +579,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user