mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Adding quickstart notebook, and corresponding gitignore rules
This commit is contained in:
parent
2597300c7f
commit
47ade41368
8
.gitignore
vendored
8
.gitignore
vendored
@ -4,3 +4,11 @@
|
||||
/jax/lib/xla_client.py
|
||||
/jax/lib/xla_data_pb2.py
|
||||
jax.egg-info
|
||||
|
||||
.ipynb_checkpoints
|
||||
|
||||
/bazel-*
|
||||
.bazelrc
|
||||
/tensorflow
|
||||
|
||||
.DS_Store
|
||||
|
572
notebooks/quickstart.ipynb
Normal file
572
notebooks/quickstart.ipynb
Normal file
@ -0,0 +1,572 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 224
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "PaW85yP_BrCF",
|
||||
"outputId": "16fe6dee-a773-4889-fdb3-ebe75f48abe5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install https://storage.googleapis.com/jax-wheels/cuda92/jax-0.0-py3-none-any.whl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "xtWX4x9DCF5_"
|
||||
},
|
||||
"source": [
|
||||
"# JAX Quickstart\n",
|
||||
"dougalm@, phawkins@, mattjj@, frostig@, alexbw@\n",
|
||||
"\n",
|
||||
"### TODO: LOGO\n",
|
||||
"\n",
|
||||
"#### [JAX](http://go/jax) is NumPy on the CPU, GPU and TPU, with great automatic differentiation for high-performance machine learning research.\n",
|
||||
"\n",
|
||||
"With its updated version of [Autograd](https://github.com/hips/autograd), JAX\n",
|
||||
"can automatically differentiate native Python and NumPy code. It can\n",
|
||||
"differentiate through a large subset of Python’s features, including loops, ifs,\n",
|
||||
"recursion, and closures, and it can even take derivatives of derivatives of\n",
|
||||
"derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n",
|
||||
"to any order.\n",
|
||||
"\n",
|
||||
"What’s new is that JAX uses\n",
|
||||
"[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)\n",
|
||||
"to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n",
|
||||
"Compilation happens under the hood by default, with library calls getting\n",
|
||||
"just-in-time compiled and executed. But JAX even lets you just-in-time compile\n",
|
||||
"your own Python functions into XLA-optimized kernels using a one-function API.\n",
|
||||
"Compilation and automatic differentiation can be composed arbitrarily, so you\n",
|
||||
"can express sophisticated algorithms and get maximal performance without having\n",
|
||||
"to leave Python.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "roHm4XGpf3rl"
|
||||
},
|
||||
"source": [
|
||||
"## The basics of JAX"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "SY8mDvEvCGqk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import print_function, division\n",
|
||||
"import numpy as onp\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import jax.numpy as np\n",
|
||||
"from jax import grad, jit, vmap\n",
|
||||
"from jax import device_put\n",
|
||||
"from jax import random"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "FQ89jHCYfhpg"
|
||||
},
|
||||
"source": [
|
||||
"### Multiplying Matrices"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Xpy1dSgNqCP4"
|
||||
},
|
||||
"source": [
|
||||
"We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you ask for random numbers. We needed to make this change to support some of the great features we talk about below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "u0nseKZNqOoH"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"x = random.normal(key, (10,))\n",
|
||||
"print(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "hDJF0UPKnuqB"
|
||||
},
|
||||
"source": [
|
||||
"Let's dive right in and multiply two big matrices."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "eXn8GUl6CG5N"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"size = 3000\n",
|
||||
"x = random.normal(key, (size, size), dtype=np.float32)\n",
|
||||
"%timeit np.dot(x, x.T) # runs on the GPU"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "0AlN7EbonyaR"
|
||||
},
|
||||
"source": [
|
||||
"JAX NumPy functions work on raw NumPy arrays as well."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ZPl0MuwYrM7t"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as onp # original CPU-backed NumPy\n",
|
||||
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
|
||||
"%timeit np.dot(x, x.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "_SrcB2IurUuE"
|
||||
},
|
||||
"source": [
|
||||
"That's slower beacuse it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using `device_put`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Jj7M7zyRskF0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
|
||||
"x = device_put(x)\n",
|
||||
"%timeit np.dot(x, x.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "clO9djnen8qi"
|
||||
},
|
||||
"source": [
|
||||
"The output of `device_put` still acts like an NDArray. By the way, the implementation of `device_put` is just `device_put = jit(lambda x: x)`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "ghkfKNQttDpg"
|
||||
},
|
||||
"source": [
|
||||
"All of these calls above are faster than original NumPy on the CPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "RzXK8GnIs7VV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
|
||||
"%timeit onp.dot(x, x.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "iOzp0P_GoJhb"
|
||||
},
|
||||
"source": [
|
||||
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numeric code. For now, there's three main ones:\n",
|
||||
"\n",
|
||||
" - `jit`, for speeding up your code\n",
|
||||
" - `grad`, for taking derivatives\n",
|
||||
" - `vmap`, for automatic vectorization or batching.\n",
|
||||
"\n",
|
||||
"Let's go over these, one-by-one. We'll also end up composing these in interesting ways."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "bTTrTbWvgLUK"
|
||||
},
|
||||
"source": [
|
||||
"### Using `jit` to speed up functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "YrqE32mvE3b7"
|
||||
},
|
||||
"source": [
|
||||
"JAX runs transparently on the GPU (or CPU, if you don't have one, and TPU coming soon!). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, JAX will incur overhead. Fortunately, JAX has a `@jit` decorator which will fuse multiple operations together. Let's try that."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "qLGdCtFKFLOR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def selu_raw(x, alpha=1.67, lmbda=1.05):\n",
|
||||
" return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)\n",
|
||||
"\n",
|
||||
"x = np.zeros(1000000)\n",
|
||||
"%timeit selu_raw(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "a_V8SruVHrD_"
|
||||
},
|
||||
"source": [
|
||||
"We can speed it up with @jit, which will jit-compile the first time `selu` is called and will be cached thereafter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "fh4w_3NpFYTp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"selu = jit(selu)\n",
|
||||
"%timeit selu(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "HxpBc4WmfsEU"
|
||||
},
|
||||
"source": [
|
||||
"### Taking derivatives with `grad`\n",
|
||||
"\n",
|
||||
"We don't just want to compute with NumPy arrays, we also want to tranform numeric programs, like by taking their derivative. In JAX, just like in Autograd, there is a one-function API for taking derivatives: the `grad` function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "IMAgNJaMJwPD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def sum_logistic(x):\n",
|
||||
" return np.sum(1.0 / (1.0 + np.exp(-x)))\n",
|
||||
"\n",
|
||||
"x_small = np.ones((3,))\n",
|
||||
"derivative_fn = grad(sum_logistic)\n",
|
||||
"print(derivative_fn(x_small))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PtNs881Ohioc"
|
||||
},
|
||||
"source": [
|
||||
"Let's verify with finite differences that our result is correct."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "JXI7_OZuKZVO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def first_finite_differences(f, x):\n",
|
||||
" eps = 1e-3\n",
|
||||
" return np.array([(f(x + eps * basis_vect) - f(x - eps * basis_vect)) / (2 * eps)\n",
|
||||
" for basis_vect in onp.eye(len(x))])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(first_finite_differences(sum_logistic, x_small))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Q2CUZjOWNZ-3"
|
||||
},
|
||||
"source": [
|
||||
"Taking derivatives is as easy as calling `grad`. `grad` and `jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "TO4g8ny-OEi4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "yCJ5feKvhnBJ"
|
||||
},
|
||||
"source": [
|
||||
"For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. We used them with `vmap` (which we'll describe in a moment) to write `jax.jacfwd` and `jax.jacrev` for computing full Jacobian matrices. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Z-JxbiNyhxEW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax import jacfwd, jacrev\n",
|
||||
"def hessian(fun):\n",
|
||||
" return jacfwd(jacrev(fun))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "TI4nPsGafxbL"
|
||||
},
|
||||
"source": [
|
||||
"### Auto-vectorization with `vmap`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PcxkONy5aius"
|
||||
},
|
||||
"source": [
|
||||
"JAX has one more transformation in its API that you might find useful: `vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with `jit`, it can be just as fast as adding the batch dimensions by hand."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "TPiX4y-bWLFS"
|
||||
},
|
||||
"source": [
|
||||
"We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap. Although this is trivial to do by hand in this specific case, the same technique can apply to more complicated functions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "cCQf0AAAa3ta"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def apply_matrix(v):\n",
|
||||
" return np.dot(mat, v)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "8w0Gpsn8WYYj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mat = random.normal(key, (150, 100))\n",
|
||||
"batched_x = random.normal(key, (10, 100))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "KWVc9BsZv0Ki"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def naively_batched_apply_matrix(v_batched):\n",
|
||||
" return np.stack([apply_matrix(v) for v in v_batched])\n",
|
||||
"\n",
|
||||
"print('Naively batched')\n",
|
||||
"%timeit naively_batched_apply_matrix(batched_x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ipei6l8nvrzH"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@jit\n",
|
||||
"def batched_apply_matrix(v_batched):\n",
|
||||
" return np.dot(v_batched, mat.T)\n",
|
||||
"\n",
|
||||
"print('Manually batched')\n",
|
||||
"%timeit batched_apply_matrix(batched_x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "67Oeknf5vuCl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@jit\n",
|
||||
"def vmap_batched_apply_matrix(batched_x):\n",
|
||||
" return vmap(apply_matrix, batched_x)\n",
|
||||
"\n",
|
||||
"print('Auto-vectorized with vmap')\n",
|
||||
"%timeit vmap_batched_apply_matrix(batched_x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "pYVl3Z2nbZhO"
|
||||
},
|
||||
"source": [
|
||||
"Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other JAX transformation. Now, let's put it all together.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "xC1CMcVNYwxm"
|
||||
},
|
||||
"source": [
|
||||
"We've now used the whole of the JAX API: `grad` for derivatives, `jit` for speedups and `vmap` for auto-vectorization.\n",
|
||||
"We used NumPy to specify all of our computation, and borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "JAX Quickstart.ipynb",
|
||||
"provenance": [],
|
||||
"toc_visible": true,
|
||||
"version": "0.3.2"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user