Add freshness metablock to JAX OSS docs.

PiperOrigin-RevId: 645508135
This commit is contained in:
jax authors 2024-06-21 14:50:02 -07:00 committed by jax authors
parent 694cafb72b
commit fc1e1d4a65
80 changed files with 164 additions and 4 deletions

View File

@ -1,5 +1,7 @@
# Custom operations for GPUs with C++ and CUDA
<!--* freshness: { reviewed: '2024-06-07' } *-->
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments.

View File

@ -15,6 +15,8 @@ kernelspec:
(advanced-autodiff)=
# Advanced automatic differentiation
<!--* freshness: { reviewed: '2024-05-14' } *-->
In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.
Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already.

View File

@ -1,5 +1,7 @@
# Advanced compilation
<!--* freshness: { reviewed: '2024-05-03' } *-->
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

View File

@ -14,6 +14,9 @@ kernelspec:
(advanced-debugging)=
# Advanced debugging
<!--* freshness: { reviewed: '2024-05-03' } *-->
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

View File

@ -22,6 +22,8 @@ kernelspec:
(external-callbacks)=
# External callbacks
<!--* freshness: { reviewed: '2024-05-16' } *-->
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.
## Why callbacks?

View File

@ -15,6 +15,8 @@ kernelspec:
(gradient-checkpointing)=
## Gradient checkpointing with `jax.checkpoint` (`jax.remat`)
<!--* freshness: { reviewed: '2024-05-03' } *-->
In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning.
If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials.

View File

@ -15,6 +15,8 @@ kernelspec:
(jax-internals-jax-primitives)=
# JAX Internals: primitives
<!--* freshness: { reviewed: '2024-05-03' } *-->
## Introduction to JAX primitives
A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).

View File

@ -15,6 +15,8 @@ kernelspec:
(jax-internals-jaxpr)=
# JAX internals: The jaxpr language
<!--* freshness: { reviewed: '2024-05-03' } *-->
Jaxprs are JAXs internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in algebraic normal form (ANF).
Conceptually, one can think of JAX transformations, such as {func}`jax.jit` or {func}`jax.grad`, as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules.

View File

@ -1,5 +1,7 @@
# Parallel computation
<!--* freshness: { reviewed: '2024-05-03' } *-->
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

View File

@ -1,5 +1,7 @@
# Profiling and performance
<!--* freshness: { reviewed: '2024-05-03' } *-->
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

View File

@ -1,5 +1,7 @@
# Example: Writing a simple neural network
<!--* freshness: { reviewed: '2024-05-03' } *-->
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
```

View File

@ -2,6 +2,8 @@
# Ahead-of-time lowering and compilation
<!--* freshness: { reviewed: '2024-06-12' } *-->
JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.

View File

@ -2,6 +2,8 @@
# API compatibility
<!--* freshness: { reviewed: '2023-07-18' } *-->
JAX is constantly evolving, and we want to be able to make improvements to its
APIs. That said, we want to minimize churn for the JAX user community, and we
try to make breaking changes rarely.

View File

@ -36,6 +36,8 @@
"source": [
"# Autodidax: JAX core from scratch\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"Ever want to learn how JAX works, but the implementation seemed impenetrable?\n",
"Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n",
"JAX's core system. You'll even get clued into our weird jargon!\n",

View File

@ -39,6 +39,8 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.
# Autodidax: JAX core from scratch
<!--* freshness: { reviewed: '2024-04-08' } *-->
Ever want to learn how JAX works, but the implementation seemed impenetrable?
Well, you're in luck! By reading this tutorial, you'll learn every big idea in
JAX's core system. You'll even get clued into our weird jargon!

View File

@ -31,6 +31,8 @@
# # Autodidax: JAX core from scratch
#
# <!--* freshness: { reviewed: '2024-04-08' } *-->
#
# Ever want to learn how JAX works, but the implementation seemed impenetrable?
# Well, you're in luck! By reading this tutorial, you'll learn every big idea in
# JAX's core system. You'll even get clued into our weird jargon!

View File

@ -15,6 +15,8 @@ kernelspec:
(automatic-differentiation)=
# Automatic differentiation
<!--* freshness: { reviewed: '2024-05-03' } *-->
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system.
Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:

View File

@ -15,6 +15,8 @@ kernelspec:
(automatic-vectorization)=
# Automatic vectorization
<!--* freshness: { reviewed: '2024-05-03' } *-->
In the previous section we discussed JIT compilation via the {func}`jax.jit` function.
This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`.

View File

@ -1,5 +1,7 @@
# Building on JAX
<!--* freshness: { reviewed: '2024-05-03' } *-->
A great way to learn advanced JAX usage is to see how other libraries are using JAX,
both how they integrate the library into their API,
what functionality it adds mathematically,

View File

@ -1,5 +1,7 @@
# Contributing to JAX
<!--* freshness: { reviewed: '2023-11-16' } *-->
Everyone can contribute to JAX, and we value everyone's contributions. There are several
ways to contribute, including:

View File

@ -15,6 +15,8 @@ kernelspec:
(debugging)=
# Introduction to debugging
<!--* freshness: { reviewed: '2024-05-10' } *-->
This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
Let's begin with {func}`jax.debug.print`.

View File

@ -1,5 +1,7 @@
# The `checkify` transformation
<!--* freshness: { reviewed: '2023-02-28' } *-->
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python

View File

@ -1,5 +1,7 @@
# JAX debugging flags
<!--* freshness: { reviewed: '2024-04-11' } *-->
JAX offers flags and context managers that enable catching errors more easily.
## `jax_debug_nans` configuration option and context manager

View File

@ -1,5 +1,7 @@
# Runtime value debugging in JAX
<!--* freshness: { reviewed: '2024-04-11' } *-->
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.
Table of contents:

View File

@ -1,5 +1,7 @@
# `jax.debug.print` and `jax.debug.breakpoint`
<!--* freshness: { reviewed: '2024-03-13' } *-->
The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of JIT-ted functions.

View File

@ -1,6 +1,8 @@
(version-support-policy)=
# Python and NumPy version support policy
<!--* freshness: { reviewed: '2024-05-02' } *-->
For NumPy and SciPy version support, JAX follows the Python scientific community's
[SPEC 0](https://scientific-python.org/specs/spec-0000/).

View File

@ -1,6 +1,8 @@
(building-from-source)=
# Building from source
<!--* freshness: { reviewed: '2024-05-15' } *-->
First, obtain the JAX source code:
```

View File

@ -1,5 +1,6 @@
# Device Memory Profiling
<!--* freshness: { reviewed: '2024-03-08' } *-->
```{note}
May 2023 update: we recommend using [Tensorboard

View File

@ -14,6 +14,8 @@ kernelspec:
# Distributed data loading in a multi-host/multi-process environment
<!--* freshness: { reviewed: '2024-05-16' } *-->
This high-level guide demonstrates how you can perform distributed data loading — when you run JAX in a {doc}`multi-host or multi-process environment <./multi_process>`, and the data required for the JAX computations is split across the multiple processes. This document covers the overall approach for how to think about distributed data loading, and then how to apply it to *data-parallel* (simpler) and *model-parallel* (more complicated) workloads.
Distributed data loading is usually more efficient (the data is split across processes) but also *more complex* compared with its alternatives, such as: 1) loading the *full global data in a single process*, splitting it up and sending the needed parts to the other processes via RPC; and 2) loading the *full global data in all processes* and only using the needed parts in each process. Loading the full global data is often simpler but more expensive. For example, in machine learning the training loop can get blocked while waiting for data, and additional network bandwidth gets used per each process.

View File

@ -1,5 +1,7 @@
# GPU performance tips
<!--* freshness: { reviewed: '2024-06-10' } *-->
This document focuses on performance tips for neural network workloads
## Matmul precision

View File

@ -1,6 +1,8 @@
(installation)=
# Installing JAX
<!--* freshness: { reviewed: '2024-06-18' } *-->
Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib` which contains compiled binaries, and requires
different builds for different operating systems and accelerators.

View File

@ -1,6 +1,8 @@
(investigating-a-regression)=
# Investigating a regression
<!--* freshness: { reviewed: '2023-11-15' } *-->
So you updated JAX and you hit a speed regression?
You have a little bit of time and are ready to investigate this?
Let's first make a JAX issue.

View File

@ -1,6 +1,8 @@
(jax-array-migration)=
# jax.Array migration
<!--* freshness: { reviewed: '2023-03-17' } *-->
**yashkatariya@**
## TL;DR

View File

@ -22,6 +22,8 @@ kernelspec:
(jit-compilation)=
# Just-in-time compilation
<!--* freshness: { reviewed: '2024-05-03' } *-->
In this section, we will further explore how JAX works, and how we can make it performant.
We will discuss the {func}`jax.jit` transformation, which will perform *Just In Time* (JIT)
compilation of a JAX Python function so it can be executed efficiently in XLA.

View File

@ -15,6 +15,8 @@ kernelspec:
(key-concepts)=
# Key Concepts
<!--* freshness: { reviewed: '2024-05-03' } *-->
This section briefly introduces some key concepts of the JAX package.
(key-concepts-jax-arrays)=

View File

@ -1,5 +1,7 @@
# Using JAX in multi-host and multi-process environments
<!--* freshness: { reviewed: '2024-06-10' } *-->
## Introduction
This guide explains how to use JAX in environments such as

View File

@ -8,6 +8,8 @@
"source": [
"# 🔪 JAX - The Sharp Bits 🔪\n",
"\n",
"<!--* freshness: { reviewed: '2024-06-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},

View File

@ -16,6 +16,8 @@ kernelspec:
# 🔪 JAX - The Sharp Bits 🔪
<!--* freshness: { reviewed: '2024-06-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+++ {"id": "4k5PVzEo2uJO"}

View File

@ -8,6 +8,8 @@
"source": [
"# Custom derivative rules for JAX-transformable Python functions\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n",

View File

@ -15,6 +15,8 @@ kernelspec:
# Custom derivative rules for JAX-transformable Python functions
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
*mattjj@ Mar 19 2020, last updated Oct 14 2020*

View File

@ -6,7 +6,9 @@
"id": "PxHrg4Cjuapm"
},
"source": [
"# Distributed arrays and automatic parallelization"
"# Distributed arrays and automatic parallelization\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-16' } *-->"
]
},
{

View File

@ -15,6 +15,8 @@ kernelspec:
# Distributed arrays and automatic parallelization
<!--* freshness: { reviewed: '2024-04-16' } *-->
+++ {"id": "pFtQjv4SzHRj"}
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)

View File

@ -8,6 +8,8 @@
"source": [
"# How JAX primitives work\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n",
"*necula@google.com*, October 2019.\n",

View File

@ -15,6 +15,8 @@ kernelspec:
# How JAX primitives work
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
*necula@google.com*, October 2019.

View File

@ -8,6 +8,8 @@
"source": [
"# Training a Simple Neural Network, with PyTorch Data Loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n",
"**Copyright 2018 The JAX Authors.**\n",

View File

@ -16,6 +16,8 @@ kernelspec:
# Training a Simple Neural Network, with PyTorch Data Loading
<!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
**Copyright 2018 The JAX Authors.**

View File

@ -8,6 +8,8 @@
"source": [
"# Writing custom Jaxpr interpreters in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
]
},

View File

@ -16,6 +16,8 @@ kernelspec:
# Writing custom Jaxpr interpreters in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+++ {"id": "r-3vMiKRYXPJ"}

View File

@ -8,6 +8,8 @@
"source": [
"# The Autodiff Cookbook\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"*alexbw@, mattjj@* \n",

View File

@ -16,6 +16,8 @@ kernelspec:
# The Autodiff Cookbook
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
*alexbw@, mattjj@*

View File

@ -6,7 +6,9 @@
"id": "29WqUVkCXjDD"
},
"source": [
"## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)"
"## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

View File

@ -15,6 +15,8 @@ kernelspec:
## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)
<!--* freshness: { reviewed: '2024-04-08' } *-->
```{code-cell}
import jax
import jax.numpy as jnp

View File

@ -8,6 +8,8 @@
"source": [
"# Generalized Convolutions in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n",

View File

@ -16,6 +16,8 @@ kernelspec:
# Generalized Convolutions in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)
JAX provides a number of interfaces to compute convolutions across data, including:

View File

@ -6,7 +6,9 @@
"id": "7XNMxdTwURqI"
},
"source": [
"# External Callbacks in JAX"
"# External Callbacks in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

View File

@ -15,6 +15,8 @@ kernelspec:
# External Callbacks in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
+++ {"id": "h6lXo6bSUYGq"}
This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.

View File

@ -38,6 +38,8 @@
"source": [
"# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n",

View File

@ -36,6 +36,8 @@ limitations under the License.
# Training a Simple Neural Network, with tensorflow/datasets Data Loading
<!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
_Forked from_ `neural_network_and_data_loading.ipynb`

View File

@ -7,6 +7,8 @@
"source": [
"# SPMD multi-device parallelism with `shard_map`\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n",
"\n",
"`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n",

View File

@ -16,6 +16,8 @@ kernelspec:
# SPMD multi-device parallelism with `shard_map`
<!--* freshness: { reviewed: '2024-04-08' } *-->
`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.
`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.

View File

@ -8,6 +8,8 @@
"source": [
"# How to Think in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."

View File

@ -15,6 +15,8 @@ kernelspec:
# How to Think in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

View File

@ -8,6 +8,8 @@
"source": [
"# Autobatching for Bayesian Inference\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",

View File

@ -16,6 +16,8 @@ kernelspec:
# Autobatching for Bayesian Inference
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

View File

@ -8,6 +8,8 @@
"source": [
"# Named axes and easy-to-revise parallelism with `xmap`\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why dont `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n",
"\n",
"This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.\n",

View File

@ -15,6 +15,8 @@ kernelspec:
# Named axes and easy-to-revise parallelism with `xmap`
<!--* freshness: { reviewed: '2024-04-08' } *-->
**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why dont `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).
This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.

View File

@ -1,5 +1,7 @@
# Pallas Design
<!--* freshness: { reviewed: '2024-04-15' } *-->
In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas's specific APIs might have changed since.
## Introduction

View File

@ -7,6 +7,8 @@
"source": [
"# Pallas Quickstart\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.\n",
"\n",
"Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n",

View File

@ -14,6 +14,8 @@ kernelspec:
# Pallas Quickstart
<!--* freshness: { reviewed: '2024-04-08' } *-->
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.
Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.

View File

@ -6,7 +6,9 @@
"id": "teoJ_fUwlu0l"
},
"source": [
"# Pipelining and `BlockSpec`s"
"# Pipelining and `BlockSpec`s\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

View File

@ -15,6 +15,8 @@ kernelspec:
# Pipelining and `BlockSpec`s
<!--* freshness: { reviewed: '2024-04-08' } *-->
+++ {"id": "gAJDZh1gBh-h"}
In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute.

View File

@ -1,5 +1,7 @@
# Persistent Compilation Cache
<!--* freshness: { reviewed: '2024-04-09' } *-->
JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.

View File

@ -1,5 +1,7 @@
# Profiling JAX programs
<!--* freshness: { reviewed: '2024-03-18' } *-->
## Viewing program traces with Perfetto
We can use the JAX profiler to generate traces of a JAX program that can be

View File

@ -16,6 +16,8 @@ language_info:
# Pytrees
<!--* freshness: { reviewed: '2024-03-13' } *-->
## What is a pytree?
In JAX, we use the term *pytree* to refer to a tree-like structure built out of

View File

@ -14,6 +14,8 @@ kernelspec:
# Quickstart
<!--* freshness: { reviewed: '2024-06-13' } *-->
**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:

View File

@ -15,6 +15,8 @@ kernelspec:
(pseudorandom-numbers)=
# Pseudorandom numbers
<!--* freshness: { reviewed: '2024-05-03' } *-->
In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.

View File

@ -7,6 +7,8 @@
"(sharded-computation)=\n",
"# Introduction to sharded computation\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-10' } *-->\n",
"\n",
"This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n",
"\n",
"The tutorial covers three modes of parallel computation:\n",

View File

@ -14,6 +14,8 @@ kernelspec:
(sharded-computation)=
# Introduction to sharded computation
<!--* freshness: { reviewed: '2024-05-10' } *-->
This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.
The tutorial covers three modes of parallel computation:

View File

@ -14,6 +14,8 @@ kernelspec:
# Stateful Computations
<!--* freshness: { reviewed: '2024-05-03' } *-->
JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions
they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have
no side effects such as updating of global state.

View File

@ -22,6 +22,8 @@ kernelspec:
(working-with-pytrees)=
# Working with pytrees
<!--* freshness: { reviewed: '2024-05-03' } *-->
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees.
This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns.