Merge pull request #488 from levskaya/master

Common Gotchas notebook and notebook fixes
This commit is contained in:
Matthew Johnson 2019-03-24 07:41:22 -07:00 committed by GitHub
commit 9711202c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 2479 additions and 2 deletions

File diff suppressed because one or more lines are too long

notebooks/README.md Normal file

@@ -0,0 +1,32 @@
# Notebooks
Use the links below to open any of these notebooks for interactive exploration in Colab.
- [Quick Start][quickstart] - the first notebook to work through; explores the basic JAX API.
- [Common Gotchas in JAX][Common_Gotchas_in_JAX] - answers to the most common problems people hit while getting used to JAX's way of doing things.
- [MAML][maml] - a pedagogical demonstration of Model-Agnostic Meta-Learning in JAX.
- [vmapped log-probabilities][vmapped log-probs] - demonstrates the utility of __vmap__ for Bayesian inference.
- [gufuncs via vmap][gufuncs] - how to implement NumPy-like gufuncs using __vmap__.
- [Neural Networks with TFDS Data][neural_network_with_tfds_data] - training a simple neural net with [TensorFlow Datasets][tfds].
- [Neural Networks and Data Loading][neural_network_and_data_loading] - training a simple neural net using a PyTorch DataLoader.
- [XLA in Python][XLA_in_Python] - an interactive exploration of the XLA compiler and computation model in Python.
[quickstart]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb
[Common_Gotchas_in_JAX]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb
[gufuncs]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/gufuncs.ipynb
[maml]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/maml.ipynb
[vmapped log-probs]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/vmapped%20log-probs.ipynb
[neural_network_with_tfds_data]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/neural_network_with_tfds_data.ipynb
[neural_network_and_data_loading]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb
[XLA_in_Python]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/XLA_in_Python.ipynb
[tfds]:https://github.com/tensorflow/datasets
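Several of the notebooks above lean on `jax.vmap`. As a quick taste of that theme, here is a minimal sketch (the function name and values are illustrative, not taken from any notebook) of vectorizing a scalar log-density the way the vmapped log-probabilities notebook does:

```python
import jax.numpy as jnp
from jax import vmap

def log_prob(x):
    # log-density of a standard normal, written for a single scalar x
    return -0.5 * (x ** 2 + jnp.log(2 * jnp.pi))

# vmap maps the scalar function over a leading batch axis,
# with no manual loops or reshaping
xs = jnp.array([0.0, 1.0, 2.0])
batched = vmap(log_prob)(xs)  # shape (3,)
```

The same `vmap`-transformed function composes with `jit` and `grad`, which is what makes this pattern useful for Bayesian inference over many data points at once.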


@@ -28,8 +28,7 @@
"\n",
"## What is a gufunc?\n",
"\n",
"[Generalized universal functions](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html) (\"gufuncs\") are one of my favorite abstractions from NumPy. They generalize NumPy's [broadcasting rules](https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html) to handle non-scalar operations. When a gufunc is applied to arrays, there are:\n",
"[Generalized universal functions](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html) (\"gufuncs\") are one of my favorite abstractions from NumPy. They generalize NumPy's [broadcasting rules](https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html) to handle non-scalar operations. When a gufunc is applied to arrays, there are:\n",
"- \"core dimensions\" over which an operation is defined.\n",
"- \"broadcast dimensions\" over which operations can be automatically vectorized.\n",
"\n",
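The core-vs-broadcast distinction described in this cell can be sketched with `jax.vmap`, which maps a function written only for core dimensions over an extra broadcast dimension (a hypothetical example, not taken from the notebook itself):

```python
import jax.numpy as jnp
from jax import vmap

def inner(a, b):
    # defined on core dimensions only, gufunc signature "(n),(n)->()"
    return jnp.dot(a, b)

# vmap supplies a broadcast dimension of size 4 out front,
# giving gufunc-like vectorization without writing any loops
a = jnp.ones((4, 3))
b = jnp.ones((4, 3))
out = vmap(inner)(a, b)  # one dot product per batch element, shape (4,)
```

Stacking `vmap` calls (or using its `in_axes` argument) handles multiple broadcast dimensions, which is the trick the gufuncs notebook builds on.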


@@ -26,6 +26,20 @@
"- extending MAML to handle batching at the task-level\n"
]
},
{
"metadata": {
"colab_type": "code",
"id": "PaW85yP_BrCF",
"colab": {}
},
"cell_type": "code",
"source": [
"!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n",
"!pip install --upgrade -q jax"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 1,


@@ -14,6 +14,20 @@
"Inspired by a notebook by @davmre."
]
},
{
"metadata": {
"colab_type": "code",
"id": "PaW85yP_BrCF",
"colab": {}
},
"cell_type": "code",
"source": [
"!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n",
"!pip install --upgrade -q jax"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 1,