"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb)\n",
"**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**\n",
"\n",
"With its updated version of [Autograd](https://github.com/HIPS/autograd), JAX\n",
"can automatically differentiate native Python and NumPy code. It can\n",
"differentiate through a large subset of Python’s features, including loops, ifs,\n",
"recursion, and closures, and it can even take derivatives of derivatives of\n",
"derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n",
"to any order.\n",
"\n",
"What’s new is that JAX uses\n",
"[XLA](https://www.tensorflow.org/xla)\n",
"to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n",
"Compilation happens under the hood by default, with library calls getting\n",
"just-in-time compiled and executed. But JAX even lets you just-in-time compile\n",
"your own Python functions into XLA-optimized kernels using a one-function API.\n",
"Compilation and automatic differentiation can be composed arbitrarily, so you\n",
"can express sophisticated algorithms and get maximal performance without having\n",
"to leave Python."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "SY8mDvEvCGqk"
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FQ89jHCYfhpg"
},
"source": [
"## Multiplying Matrices"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xpy1dSgNqCP4"
},
"source": [
"We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see [Common Gotchas in JAX].\n",
"\n",
"[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "u0nseKZNqOoH"
},
"outputs": [],
"source": [
"key = random.PRNGKey(0)\n",
"x = random.normal(key, (10,))\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hDJF0UPKnuqB"
},
"source": [
"Let's dive right in and multiply two big matrices."
]
},
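Here is a rough timing sketch. The matrix size is arbitrary, and `time.perf_counter` stands in for a notebook's `%timeit`; `block_until_ready()` is needed because JAX dispatches work asynchronously, so without it you would only time the enqueueing of the operation:

```python
import time

import numpy as np
import jax.numpy as jnp
from jax import random

def time_matmul(x, label):
    # Warm-up call so one-time compilation cost is not measured.
    jnp.dot(x, x.T).block_until_ready()
    start = time.perf_counter()
    jnp.dot(x, x.T).block_until_ready()  # block: JAX dispatches asynchronously
    print(f"{label}: {time.perf_counter() - start:.4f} s")

size = 3000
key = random.PRNGKey(0)

# A JAX array already resident in device memory.
device_x = random.normal(key, (size, size), dtype=jnp.float32)
time_matmul(device_x, "JAX device array")

# The same operation on a host-resident NumPy array: JAX accepts it,
# but must transfer it to the device on every call.
host_x = np.random.normal(size=(size, size)).astype(np.float32)
time_matmul(host_x, "NumPy host array")
```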
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using {func}`~jax.device_put`."
]
},
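A minimal sketch of that pattern (the array size is arbitrary):

```python
import numpy as np
import jax.numpy as jnp
from jax import device_put

size = 3000
x = np.random.normal(size=(size, size)).astype(np.float32)
# One-time copy to device memory; the result still acts like an ndarray.
x = device_put(x)
result = jnp.dot(x, x.T).block_until_ready()
```

After the `device_put`, repeated operations on `x` no longer pay the host-to-device transfer cost.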
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output of {func}`~jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of {func}`~jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n",
"\n",
" - {func}`~jax.jit`, for speeding up your code\n",
" - {func}`~jax.grad`, for taking derivatives\n",
" - {func}`~jax.vmap`, for automatic vectorization or batching"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using `jit` to Speed Up Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that."
]
},
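As a sketch, here is an element-wise function compiled into a single fused kernel with `jit` (the SELU activation and its constants are used purely as an example):

```python
import jax.numpy as jnp
from jax import jit, random

def selu(x, alpha=1.67, lmbda=1.05):
    # Element-wise SELU activation: several ops that jit can fuse together.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.PRNGKey(0)
x = random.normal(key, (1000000,))

selu_jit = jit(selu)
# The first call triggers compilation; subsequent calls reuse the kernel.
selu_jit(x).block_until_ready()
```

Without `jit`, each of the operations inside `selu` would be dispatched to the accelerator separately; with it, XLA compiles them as one unit.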
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Taking Derivatives with `grad`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the {func}`~jax.grad` function."
]
},
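For instance, a sketch on a scalar-valued function (the logistic sum here is just an illustration):

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    # Scalar-valued function: sum of logistic sigmoids.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.0)
derivative_fn = grad(sum_logistic)  # gradient w.r.t. the first argument
print(derivative_fn(x_small))
```

The result can be checked against the analytic derivative of the logistic function, s(x)(1 - s(x)), or against finite differences.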
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Taking derivatives is as easy as calling {func}`~jax.grad`. {func}`~jax.grad` and {func}`~jax.jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:"
]
},
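For example, a third derivative that freely alternates `jit` and `grad` (a sketch; `sum_logistic` is the same illustrative scalar-output function as above):

```python
import jax.numpy as jnp
from jax import grad, jit

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

# Third derivative of the logistic at 1.0, mixing jit and grad arbitrarily.
third_deriv = grad(jit(grad(jit(grad(sum_logistic)))))
print(third_deriv(1.0))
```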
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:"
]
},
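One such composition, sketched on a toy cubic function (the test function is our choice for illustration):

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # Forward-over-reverse is typically the efficient order for Hessians:
    # jacrev computes the gradient; jacfwd differentiates it once more.
    return jit(jacfwd(jacrev(fun)))

def cubic(x):
    return jnp.sum(x ** 3)

x = jnp.arange(3.0)
print(hessian(cubic)(x))  # Hessian of sum(x**3) is diag(6 * x)
```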
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Auto-Vectorization with `vmap`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX has one more transformation in its API that you might find useful: {func}`~jax.vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions by hand."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions."
]
},
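A sketch (shapes are arbitrary): `apply_matrix` handles a single vector, and `vmap` promotes it to operate over a leading batch axis without an explicit Python loop.

```python
import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))  # batch of 10 vectors

def apply_matrix(v):
    return jnp.dot(mat, v)  # single matrix-vector product

def naively_batched(v_batched):
    # Explicit Python loop over the batch: correct but slow.
    return jnp.stack([apply_matrix(v) for v in v_batched])

# vmap pushes the batch loop down into the primitive operations,
# turning the matrix-vector products into one matrix-matrix product.
batched = vmap(apply_matrix)

assert jnp.allclose(naively_batched(batched_x), batched(batched_x), atol=1e-4)
```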