mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #488 from levskaya/master
Common Gotchas notebook and notebook fixes
This commit is contained in:
commit
9711202c18
2418
notebooks/Common_Gotchas_in_JAX.ipynb
Normal file
2418
notebooks/Common_Gotchas_in_JAX.ipynb
Normal file
File diff suppressed because one or more lines are too long
32
notebooks/README.md
Normal file
32
notebooks/README.md
Normal file
@ -0,0 +1,32 @@
|
||||
# Notebooks
|
||||
|
||||
Use the links below to open any of these for interactive exploration in colab.
|
||||
|
||||
- [Quick Start][quickstart] - the first notebook to go through, explores the basic JAX API.
|
||||
|
||||
- [Common Gotchas in JAX][Common_Gotchas_in_JAX] - answers for the most common problems people have while getting used to JAX's way of doing things.
|
||||
|
||||
- [MAML][maml] - pedagogical demonstration of Model-Agnostic Meta-Learning in JAX.
|
||||
|
||||
- [vmapped log-probabilities][vmapped log-probs] - demonstrates the utility of __vmap__ for Bayesian inference.
|
||||
|
||||
- [gufuncs via vmap][gufuncs] - how to implement NumPy-like gufuncs using __vmap__.
|
||||
|
||||
- [Neural Networks with TFDS Data][neural_network_with_tfds_data] - training a simple neural net with [tensorflow datasets][tfds].
|
||||
|
||||
- [Neural Networks and Data Loading][neural_network_and_data_loading] - training a simple neural net using a pytorch dataloader.
|
||||
|
||||
- [XLA in Python][XLA_in_Python] - interactive exploration of the XLA compiler and computation model in python.
|
||||
|
||||
|
||||
|
||||
[quickstart]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb
|
||||
[Common_Gotchas_in_JAX]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb
|
||||
[gufuncs]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/gufuncs.ipynb
|
||||
[maml]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/maml.ipynb
|
||||
[vmapped log-probs]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/vmapped%20log-probs.ipynb
|
||||
[neural_network_with_tfds_data]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/neural_network_with_tfds_data.ipynb
|
||||
[neural_network_and_data_loading]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb
|
||||
[XLA_in_Python]:https://colab.sandbox.google.com/github/google/jax/blob/master/notebooks/XLA_in_Python.ipynb
|
||||
[tfds]:https://github.com/tensorflow/datasets
|
||||
|
@ -28,8 +28,7 @@
|
||||
"\n",
|
||||
"## What is a gufunc?\n",
|
||||
"\n",
|
||||
"[Generalized universal functions](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html) (\"gufuncs\") are one of my favorite abstractions from NumPy. They generalize NumPy's
|
||||
[broadcasting rules](https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html) to handle non-scalar operations. When a gufuncs is applied to arrays, there are:\n",
|
||||
"[Generalized universal functions](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html) (\"gufuncs\") are one of my favorite abstractions from NumPy. They generalize NumPy's [broadcasting rules](https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html) to handle non-scalar operations. When a gufuncs is applied to arrays, there are:\n",
|
||||
"- \"core dimensions\" over which an operation is defined.\n",
|
||||
"- \"broadcast dimensions\" over which operations can be automatically vectorized.\n",
|
||||
"\n",
|
||||
|
@ -26,6 +26,20 @@
|
||||
"- extending MAML to handle batching at the task-level\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "PaW85yP_BrCF",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
|
@ -14,6 +14,20 @@
|
||||
"Inspired by a notebook by @davmre."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab_type": "code",
|
||||
"id": "PaW85yP_BrCF",
|
||||
"colab": {}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n",
|
||||
"!pip install --upgrade -q jax"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
|
Loading…
x
Reference in New Issue
Block a user