Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
This commit is contained in: parent 694cafb72b, commit fc1e1d4a65
@@ -1,5 +1,7 @@
# Custom operations for GPUs with C++ and CUDA

<!--* freshness: { reviewed: '2024-06-07' } *-->

JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.

To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.

@@ -15,6 +15,8 @@ kernelspec:
(advanced-autodiff)=
# Advanced automatic differentiation

<!--* freshness: { reviewed: '2024-05-14' } *-->

In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.

Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already.

@@ -1,5 +1,7 @@
# Advanced compilation

<!--* freshness: { reviewed: '2024-05-03' } *-->

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

@@ -14,6 +14,9 @@ kernelspec:
(advanced-debugging)=
# Advanced debugging

<!--* freshness: { reviewed: '2024-05-03' } *-->

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

@@ -22,6 +22,8 @@ kernelspec:
(external-callbacks)=
# External callbacks

<!--* freshness: { reviewed: '2024-05-16' } *-->

This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.

## Why callbacks?
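For readers skimming this diff, a minimal sketch of the {func}`jax.pure_callback` pattern the callbacks tutorial goes on to describe: the callback runs on the host with NumPy arrays, and the expected output shape and dtype must be declared up front. The `host_log` helper below is illustrative only, not taken from the tutorial.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_log(x):
    # Runs on the host and receives a NumPy array, not a traced value.
    return np.log(x)

@jax.jit
def f(x):
    # pure_callback needs the result shape/dtype declared ahead of time.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log, result_shape, x)

print(f(jnp.arange(1.0, 4.0)))  # works under jit; vmap/grad carry extra caveats
```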
@@ -15,6 +15,8 @@ kernelspec:
(gradient-checkpointing)=
## Gradient checkpointing with `jax.checkpoint` (`jax.remat`)

<!--* freshness: { reviewed: '2024-05-03' } *-->

In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning.

If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials.
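As a rough sketch of what that tutorial covers: wrapping a sub-function in {func}`jax.checkpoint` tells autodiff to recompute that sub-function's intermediates on the backward pass instead of storing them. The two-layer `f` below is a made-up example, not code from the doc.

```python
import jax
import jax.numpy as jnp

def layer(W, x):
    return jnp.sin(jnp.dot(W, x))

def f(W1, W2, x):
    # Intermediates of each checkpointed layer are recomputed during the
    # backward pass rather than saved, trading compute for memory.
    x = jax.checkpoint(layer)(W1, x)
    x = jax.checkpoint(layer)(W2, x)
    return jnp.sum(x)

W1 = jnp.ones((4, 4)); W2 = jnp.ones((4, 4)); x = jnp.ones(4)
print(jax.grad(f)(W1, W2, x).shape)  # (4, 4)
```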
@@ -15,6 +15,8 @@ kernelspec:
(jax-internals-jax-primitives)=
# JAX Internals: primitives

<!--* freshness: { reviewed: '2024-05-03' } *-->

## Introduction to JAX primitives

A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).

@@ -15,6 +15,8 @@ kernelspec:
(jax-internals-jaxpr)=
# JAX internals: The jaxpr language

<!--* freshness: { reviewed: '2024-05-03' } *-->

Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF).

Conceptually, one can think of JAX transformations, such as {func}`jax.jit` or {func}`jax.grad`, as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules.
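A quick way to see a jaxpr, for readers who have not met the IR before (the function below is just an example, not from the doc):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * jnp.cos(y)

# Trace f with abstract values and print its jaxpr instead of executing it.
print(jax.make_jaxpr(f)(1.0, 2.0))
```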
@@ -1,5 +1,7 @@
# Parallel computation

<!--* freshness: { reviewed: '2024-05-03' } *-->

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

@@ -1,5 +1,7 @@
# Profiling and performance

<!--* freshness: { reviewed: '2024-05-03' } *-->

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

@@ -1,5 +1,7 @@
# Example: Writing a simple neural network

<!--* freshness: { reviewed: '2024-05-03' } *-->

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
```

@@ -2,6 +2,8 @@
# Ahead-of-time lowering and compilation

<!--* freshness: { reviewed: '2024-06-12' } *-->

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.
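For context, a minimal sketch of the ahead-of-time path that document introduces, using the public `.lower()` and `.compile()` methods on a jitted function; the toy `f` is illustrative.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

x, y = jnp.float32(1.0), jnp.float32(2.0)
lowered = jax.jit(f).lower(x, y)   # stage the program out for these shapes/dtypes
compiled = lowered.compile()       # compile now, before any "real" execution
print(compiled(x, y))              # run the precompiled executable
```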
@@ -2,6 +2,8 @@
# API compatibility

<!--* freshness: { reviewed: '2023-07-18' } *-->

JAX is constantly evolving, and we want to be able to make improvements to its
APIs. That said, we want to minimize churn for the JAX user community, and we
try to make breaking changes rarely.

@@ -36,6 +36,8 @@
"source": [
"# Autodidax: JAX core from scratch\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"Ever want to learn how JAX works, but the implementation seemed impenetrable?\n",
"Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n",
"JAX's core system. You'll even get clued into our weird jargon!\n",

@@ -39,6 +39,8 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.
# Autodidax: JAX core from scratch

<!--* freshness: { reviewed: '2024-04-08' } *-->

Ever want to learn how JAX works, but the implementation seemed impenetrable?
Well, you're in luck! By reading this tutorial, you'll learn every big idea in
JAX's core system. You'll even get clued into our weird jargon!

@@ -31,6 +31,8 @@
# # Autodidax: JAX core from scratch
#
# <!--* freshness: { reviewed: '2024-04-08' } *-->
#
# Ever want to learn how JAX works, but the implementation seemed impenetrable?
# Well, you're in luck! By reading this tutorial, you'll learn every big idea in
# JAX's core system. You'll even get clued into our weird jargon!

@@ -15,6 +15,8 @@ kernelspec:
(automatic-differentiation)=
# Automatic differentiation

<!--* freshness: { reviewed: '2024-05-03' } *-->

In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system.
Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
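As a taste of the basics that tutorial builds on, a small sketch of {func}`jax.grad`; the loss function here is made up for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.5, -1.0, 2.0])
print(jax.grad(loss)(w))                          # gradient of a scalar-valued function
print(jax.grad(jax.grad(lambda x: x ** 3))(2.0))  # grads compose: second derivative of x^3 at 2 is 12
```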
@@ -15,6 +15,8 @@ kernelspec:
(automatic-vectorization)=
# Automatic vectorization

<!--* freshness: { reviewed: '2024-05-03' } *-->

In the previous section we discussed JIT compilation via the {func}`jax.jit` function.
This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`.
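A minimal sketch of the idea, assuming a function written for single examples (the `dot` helper is illustrative):

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)        # written for a single pair of vectors

batched_dot = jax.vmap(dot)     # map over a leading batch axis of both arguments

a = jnp.ones((8, 3))
b = jnp.ones((8, 3))
print(batched_dot(a, b).shape)  # (8,) - one dot product per batch element
```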
@@ -1,5 +1,7 @@
# Building on JAX

<!--* freshness: { reviewed: '2024-05-03' } *-->

A great way to learn advanced JAX usage is to see how other libraries are using JAX,
both how they integrate the library into their API,
what functionality it adds mathematically,

@@ -1,5 +1,7 @@
# Contributing to JAX

<!--* freshness: { reviewed: '2023-11-16' } *-->

Everyone can contribute to JAX, and we value everyone's contributions. There are several
ways to contribute, including:

@@ -15,6 +15,8 @@ kernelspec:
(debugging)=
# Introduction to debugging

<!--* freshness: { reviewed: '2024-05-10' } *-->

This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.

Let's begin with {func}`jax.debug.print`.
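A small sketch of that first tool, {func}`jax.debug.print`, which prints runtime values even inside jitted code; the function below is an illustration, not taken from the doc.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("x = {x}", x=x)        # printed at runtime, unlike Python's print under jit
    y = jnp.sin(x)
    jax.debug.print("sin(x) = {y}", y=y)
    return y

f(2.0)
```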
@@ -1,5 +1,7 @@
# The `checkify` transformation

<!--* freshness: { reviewed: '2023-02-28' } *-->

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

```python

@@ -1,5 +1,7 @@
# JAX debugging flags

<!--* freshness: { reviewed: '2024-04-11' } *-->

JAX offers flags and context managers that enable catching errors more easily.

## `jax_debug_nans` configuration option and context manager
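A minimal sketch of the flag that section documents: with `jax_debug_nans` enabled, the operation that produced a NaN is re-run eagerly and raises an error instead of letting the NaN propagate silently. The failing function below is contrived for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)  # can also be set via environment variable

@jax.jit
def f(x):
    return jnp.log(x)  # NaN for negative inputs

try:
    f(-1.0)
except FloatingPointError as e:
    print("caught:", e)
```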
@@ -1,5 +1,7 @@
# Runtime value debugging in JAX

<!--* freshness: { reviewed: '2024-04-11' } *-->

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.

Table of contents:

@@ -1,5 +1,7 @@
# `jax.debug.print` and `jax.debug.breakpoint`

<!--* freshness: { reviewed: '2024-03-13' } *-->

The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of JIT-ted functions.

@@ -1,6 +1,8 @@
(version-support-policy)=
# Python and NumPy version support policy

<!--* freshness: { reviewed: '2024-05-02' } *-->

For NumPy and SciPy version support, JAX follows the Python scientific community's
[SPEC 0](https://scientific-python.org/specs/spec-0000/).

@@ -1,6 +1,8 @@
(building-from-source)=
# Building from source

<!--* freshness: { reviewed: '2024-05-15' } *-->

First, obtain the JAX source code:

```
@@ -1,5 +1,6 @@
# Device Memory Profiling

<!--* freshness: { reviewed: '2024-03-08' } *-->

```{note}
May 2023 update: we recommend using [Tensorboard

@@ -14,6 +14,8 @@ kernelspec:
# Distributed data loading in a multi-host/multi-process environment

<!--* freshness: { reviewed: '2024-05-16' } *-->

This high-level guide demonstrates how you can perform distributed data loading — when you run JAX in a {doc}`multi-host or multi-process environment <./multi_process>`, and the data required for the JAX computations is split across the multiple processes. This document covers the overall approach for how to think about distributed data loading, and then how to apply it to *data-parallel* (simpler) and *model-parallel* (more complicated) workloads.

Distributed data loading is usually more efficient (the data is split across processes) but also *more complex* compared with its alternatives, such as: 1) loading the *full global data in a single process*, splitting it up and sending the needed parts to the other processes via RPC; and 2) loading the *full global data in all processes* and only using the needed parts in each process. Loading the full global data is often simpler but more expensive. For example, in machine learning the training loop can get blocked while waiting for data, and additional network bandwidth gets used per each process.

@@ -1,5 +1,7 @@
# GPU performance tips

<!--* freshness: { reviewed: '2024-06-10' } *-->

This document focuses on performance tips for neural network workloads

## Matmul precision

@@ -1,6 +1,8 @@
(installation)=
# Installing JAX

<!--* freshness: { reviewed: '2024-06-18' } *-->

Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib` which contains compiled binaries, and requires
different builds for different operating systems and accelerators.
@@ -1,6 +1,8 @@
(investigating-a-regression)=
# Investigating a regression

<!--* freshness: { reviewed: '2023-11-15' } *-->

So you updated JAX and you hit a speed regression?
You have a little bit of time and are ready to investigate this?
Let's first make a JAX issue.

@@ -1,6 +1,8 @@
(jax-array-migration)=
# jax.Array migration

<!--* freshness: { reviewed: '2023-03-17' } *-->

**yashkatariya@**

## TL;DR

@@ -22,6 +22,8 @@ kernelspec:
(jit-compilation)=
# Just-in-time compilation

<!--* freshness: { reviewed: '2024-05-03' } *-->

In this section, we will further explore how JAX works, and how we can make it performant.
We will discuss the {func}`jax.jit` transformation, which will perform *Just In Time* (JIT)
compilation of a JAX Python function so it can be executed efficiently in XLA.
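For orientation, a minimal sketch of {func}`jax.jit` in action; `selu` is the stock example used in several JAX docs.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)     # traced and compiled with XLA on first call
x = jnp.arange(5.0)
print(selu_jit(x))           # later calls with the same shapes/dtypes reuse the compiled code
```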
@@ -15,6 +15,8 @@ kernelspec:
(key-concepts)=
# Key Concepts

<!--* freshness: { reviewed: '2024-05-03' } *-->

This section briefly introduces some key concepts of the JAX package.

(key-concepts-jax-arrays)=

@@ -1,5 +1,7 @@
# Using JAX in multi-host and multi-process environments

<!--* freshness: { reviewed: '2024-06-10' } *-->

## Introduction

This guide explains how to use JAX in environments such as

@@ -8,6 +8,8 @@
"source": [
"# 🔪 JAX - The Sharp Bits 🔪\n",
"\n",
"<!--* freshness: { reviewed: '2024-06-03' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},

@@ -16,6 +16,8 @@ kernelspec:
# 🔪 JAX - The Sharp Bits 🔪

<!--* freshness: { reviewed: '2024-06-03' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)

+++ {"id": "4k5PVzEo2uJO"}
@@ -8,6 +8,8 @@
"source": [
"# Custom derivative rules for JAX-transformable Python functions\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n",

@@ -15,6 +15,8 @@ kernelspec:
# Custom derivative rules for JAX-transformable Python functions

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)

*mattjj@ Mar 19 2020, last updated Oct 14 2020*
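A compressed sketch of the pattern that notebook teaches, using the well-known numerically stable `log1pexp` example; the hand-written JVP avoids the NaN the automatic rule would produce for large inputs.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # Stable derivative: sigmoid(x) * x_dot, written to avoid inf/inf.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(100.0))  # 1.0 instead of nan
```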
@@ -6,7 +6,9 @@
"id": "PxHrg4Cjuapm"
},
"source": [
"# Distributed arrays and automatic parallelization"
"# Distributed arrays and automatic parallelization\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-16' } *-->"
]
},
{

@@ -15,6 +15,8 @@ kernelspec:
# Distributed arrays and automatic parallelization

<!--* freshness: { reviewed: '2024-04-16' } *-->

+++ {"id": "pFtQjv4SzHRj"}

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)

@@ -8,6 +8,8 @@
"source": [
"# How JAX primitives work\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n",
"*necula@google.com*, October 2019.\n",

@@ -15,6 +15,8 @@ kernelspec:
# How JAX primitives work

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)

*necula@google.com*, October 2019.
@@ -8,6 +8,8 @@
"source": [
"# Training a Simple Neural Network, with PyTorch Data Loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n",
"**Copyright 2018 The JAX Authors.**\n",

@@ -16,6 +16,8 @@ kernelspec:
# Training a Simple Neural Network, with PyTorch Data Loading

<!--* freshness: { reviewed: '2024-05-03' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)

**Copyright 2018 The JAX Authors.**

@@ -8,6 +8,8 @@
"source": [
"# Writing custom Jaxpr interpreters in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
]
},

@@ -16,6 +16,8 @@ kernelspec:
# Writing custom Jaxpr interpreters in JAX

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)

+++ {"id": "r-3vMiKRYXPJ"}
@@ -8,6 +8,8 @@
"source": [
"# The Autodiff Cookbook\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"*alexbw@, mattjj@* \n",

@@ -16,6 +16,8 @@ kernelspec:
# The Autodiff Cookbook

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)

*alexbw@, mattjj@*

@@ -6,7 +6,9 @@
"id": "29WqUVkCXjDD"
},
"source": [
"## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)"
"## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

@@ -15,6 +15,8 @@ kernelspec:
## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)

<!--* freshness: { reviewed: '2024-04-08' } *-->

```{code-cell}
import jax
import jax.numpy as jnp
@@ -8,6 +8,8 @@
"source": [
"# Generalized Convolutions in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n",

@@ -16,6 +16,8 @@ kernelspec:
# Generalized Convolutions in JAX

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)

JAX provides a number of interfaces to compute convolutions across data, including:

@@ -6,7 +6,9 @@
"id": "7XNMxdTwURqI"
},
"source": [
"# External Callbacks in JAX"
"# External Callbacks in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

@@ -15,6 +15,8 @@ kernelspec:
# External Callbacks in JAX

<!--* freshness: { reviewed: '2024-04-08' } *-->

+++ {"id": "h6lXo6bSUYGq"}

This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.
@@ -38,6 +38,8 @@
"source": [
"# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n",

@@ -36,6 +36,8 @@ limitations under the License.
# Training a Simple Neural Network, with tensorflow/datasets Data Loading

<!--* freshness: { reviewed: '2024-05-03' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)

_Forked from_ `neural_network_and_data_loading.ipynb`

@@ -7,6 +7,8 @@
"source": [
"# SPMD multi-device parallelism with `shard_map`\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n",
"\n",
"`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n",

@@ -16,6 +16,8 @@ kernelspec:
# SPMD multi-device parallelism with `shard_map`

<!--* freshness: { reviewed: '2024-04-08' } *-->

`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.

`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.
@@ -8,6 +8,8 @@
"source": [
"# How to Think in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."

@@ -15,6 +15,8 @@ kernelspec:
# How to Think in JAX

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)

JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

@@ -8,6 +8,8 @@
"source": [
"# Autobatching for Bayesian Inference\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",

@@ -16,6 +16,8 @@ kernelspec:
# Autobatching for Bayesian Inference

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)

This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
@@ -8,6 +8,8 @@
"source": [
"# Named axes and easy-to-revise parallelism with `xmap`\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n",
"\n",
"This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.\n",

@@ -15,6 +15,8 @@ kernelspec:
# Named axes and easy-to-revise parallelism with `xmap`

<!--* freshness: { reviewed: '2024-04-08' } *-->

**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).

This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.

@@ -1,5 +1,7 @@
# Pallas Design

<!--* freshness: { reviewed: '2024-04-15' } *-->

In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas's specific APIs might have changed since.

## Introduction
@@ -7,6 +7,8 @@
"source": [
"# Pallas Quickstart\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.\n",
"\n",
"Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n",

@@ -14,6 +14,8 @@ kernelspec:
# Pallas Quickstart

<!--* freshness: { reviewed: '2024-04-08' } *-->

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.

Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.

@@ -6,7 +6,9 @@
"id": "teoJ_fUwlu0l"
},
"source": [
"# Pipelining and `BlockSpec`s"
"# Pipelining and `BlockSpec`s\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},
{

@@ -15,6 +15,8 @@ kernelspec:
# Pipelining and `BlockSpec`s

<!--* freshness: { reviewed: '2024-04-08' } *-->

+++ {"id": "gAJDZh1gBh-h"}

In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute.
@@ -1,5 +1,7 @@
# Persistent Compilation Cache

<!--* freshness: { reviewed: '2024-04-09' } *-->

JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.
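For reference, a minimal sketch of turning the cache on from Python; it assumes a writable directory and the `jax_compilation_cache_dir` option described in that doc.

```python
import jax
import jax.numpy as jnp

# Point the persistent cache at a directory; compiled executables written here
# can be reused by later processes, skipping recompilation.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

@jax.jit
def f(x):
    return x * 2 + 1

print(f(jnp.arange(3)))
```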
@@ -1,5 +1,7 @@
# Profiling JAX programs

<!--* freshness: { reviewed: '2024-03-18' } *-->

## Viewing program traces with Perfetto

We can use the JAX profiler to generate traces of a JAX program that can be

@@ -16,6 +16,8 @@ language_info:
# Pytrees

<!--* freshness: { reviewed: '2024-03-13' } *-->

## What is a pytree?

In JAX, we use the term *pytree* to refer to a tree-like structure built out of

@@ -14,6 +14,8 @@ kernelspec:
# Quickstart

<!--* freshness: { reviewed: '2024-06-13' } *-->

**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.

This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
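As a one-screen taste of the features the Quickstart walks through (the tiny model below is illustrative, not from the doc):

```python
import jax
import jax.numpy as jnp

@jax.jit                               # JIT-compile the whole prediction function
def predict(w, x):
    return jnp.tanh(x @ w)

loss = lambda w, x: jnp.sum(predict(w, x) ** 2)
grad_fn = jax.grad(loss)               # automatic differentiation with respect to w

w = jnp.ones((3, 2))
x = jnp.ones((5, 3))
print(grad_fn(w, x).shape)             # (3, 2): same shape as the parameters
```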
@@ -15,6 +15,8 @@ kernelspec:
(pseudorandom-numbers)=
# Pseudorandom numbers

<!--* freshness: { reviewed: '2024-05-03' } *-->

In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.

PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.
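The core usage pattern, in brief: keys are explicit values that you split before each draw rather than a hidden global state. A short sketch with arbitrary shapes:

```python
import jax

key = jax.random.key(42)              # explicit PRNG state instead of a global seed
key, subkey = jax.random.split(key)   # split, and never reuse a consumed key
x = jax.random.normal(subkey, (3,))

key, subkey = jax.random.split(key)   # a fresh subkey for every independent draw
y = jax.random.uniform(subkey, (3,))
print(x, y)
```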
@@ -7,6 +7,8 @@
"(sharded-computation)=\n",
"# Introduction to sharded computation\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-10' } *-->\n",
"\n",
"This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n",
"\n",
"The tutorial covers three modes of parallel computation:\n",

@@ -14,6 +14,8 @@ kernelspec:
(sharded-computation)=
# Introduction to sharded computation

<!--* freshness: { reviewed: '2024-05-10' } *-->

This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.

The tutorial covers three modes of parallel computation:

@@ -14,6 +14,8 @@ kernelspec:
# Stateful Computations

<!--* freshness: { reviewed: '2024-05-03' } *-->

JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad` require the functions
they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have
no side effects such as updating of global state.
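A small sketch of the functional pattern the stateful-computations tutorial arrives at: instead of mutating a counter in place, the state is passed in and the updated state is returned. The counter itself is a toy example.

```python
import jax
import jax.numpy as jnp

def counter_step(count):
    # Pure function: new state out, nothing mutated in place.
    new_count = count + 1
    return new_count, new_count ** 2   # (new state, output)

@jax.jit
def run(count):
    count, out = counter_step(count)
    return count, out

state = jnp.int32(0)
state, out = run(state)
state, out = run(state)
print(state, out)   # 2 4
```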
@@ -22,6 +22,8 @@ kernelspec:
(working-with-pytrees)=
# Working with pytrees

<!--* freshness: { reviewed: '2024-05-03' } *-->

JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees.
This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns.
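A minimal sketch of the kind of manipulation that section covers, applying one function to every leaf of a nested parameter dict (the parameter shapes are made up):

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "layer2": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
}

# tree_map applies the function to every leaf array, preserving the structure.
scaled = jax.tree_util.tree_map(lambda p: 0.1 * p, params)
print(jax.tree_util.tree_structure(scaled))
print(len(jax.tree_util.tree_leaves(scaled)))  # 4 leaf arrays
```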