diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index ab5bad835..2d9c43831 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -1,5 +1,7 @@ # Custom operations for GPUs with C++ and CUDA + + JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX. To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments. diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 56ac53baf..da95f96d8 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -15,6 +15,8 @@ kernelspec: (advanced-autodiff)= # Advanced automatic differentiation + + In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful. Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already. diff --git a/docs/_tutorials/advanced-compilation.md b/docs/_tutorials/advanced-compilation.md index a3aeeaf3c..09535f2fc 100644 --- a/docs/_tutorials/advanced-compilation.md +++ b/docs/_tutorials/advanced-compilation.md @@ -1,5 +1,7 @@ # Advanced compilation + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md index 34d15e30b..56188e095 100644 --- a/docs/_tutorials/advanced-debugging.md +++ b/docs/_tutorials/advanced-debugging.md @@ -14,6 +14,9 @@ kernelspec: (advanced-debugging)= # Advanced debugging + + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/external-callbacks.md b/docs/_tutorials/external-callbacks.md index 0420afaaa..a46927e6a 100644 --- a/docs/_tutorials/external-callbacks.md +++ b/docs/_tutorials/external-callbacks.md @@ -22,6 +22,8 @@ kernelspec: (external-callbacks)= # External callbacks + + This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. ## Why callbacks? diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/_tutorials/gradient-checkpointing.md index a15ee1941..b768514e4 100644 --- a/docs/_tutorials/gradient-checkpointing.md +++ b/docs/_tutorials/gradient-checkpointing.md @@ -15,6 +15,8 @@ kernelspec: (gradient-checkpointing)= ## Gradient checkpointing with `jax.checkpoint` (`jax.remat`) + + In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning. If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials.
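As a quick illustration of the `jax.checkpoint` behavior described in the gradient-checkpointing excerpt above, here is a minimal sketch (an editorial illustration, not part of the patched docs); it uses only the public `jax` API:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(jnp.sin(x))

# jax.checkpoint (a.k.a. jax.remat) tells autodiff to recompute intermediate
# values during the backward pass instead of storing them from the forward pass.
g_remat = jax.checkpoint(g)

print(jax.grad(g)(2.0))        # same gradient value...
print(jax.grad(g_remat)(2.0))  # ...computed with fewer saved residuals
```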
diff --git a/docs/_tutorials/jax-primitives.md b/docs/_tutorials/jax-primitives.md index e5fab275a..51abe2916 100644 --- a/docs/_tutorials/jax-primitives.md +++ b/docs/_tutorials/jax-primitives.md @@ -15,6 +15,8 @@ kernelspec: (jax-internals-jax-primitives)= # JAX Internals: primitives + + ## Introduction to JAX primitives A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). diff --git a/docs/_tutorials/jaxpr.md b/docs/_tutorials/jaxpr.md index 03c0bef08..9fe990c0a 100644 --- a/docs/_tutorials/jaxpr.md +++ b/docs/_tutorials/jaxpr.md @@ -15,6 +15,8 @@ kernelspec: (jax-internals-jaxpr)= # JAX internals: The jaxpr language + + Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF). Conceptually, one can think of JAX transformations, such as {func}`jax.jit` or {func}`jax.grad`, as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. diff --git a/docs/_tutorials/parallelism.md b/docs/_tutorials/parallelism.md index 8bf695743..9b840357e 100644 --- a/docs/_tutorials/parallelism.md +++ b/docs/_tutorials/parallelism.md @@ -1,5 +1,7 @@ # Parallel computation + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/profiling-and-performance.md b/docs/_tutorials/profiling-and-performance.md index e540c920e..d9a13b213 100644 --- a/docs/_tutorials/profiling-and-performance.md +++ b/docs/_tutorials/profiling-and-performance.md @@ -1,5 +1,7 @@ # Profiling and performance + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/simple-neural-network.md b/docs/_tutorials/simple-neural-network.md index b5c91ffd0..76e98db88 100644 --- a/docs/_tutorials/simple-neural-network.md +++ b/docs/_tutorials/simple-neural-network.md @@ -1,5 +1,7 @@ # Example: Writing a simple neural network + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. ``` diff --git a/docs/aot.md b/docs/aot.md index 8615d7513..ed7f45749 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -2,6 +2,8 @@ # Ahead-of-time lowering and compilation + + JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning a function that is compiled and runs on accelerators or the CPU. As the JIT acronym indicates, all compilation happens _just-in-time_ for execution. diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 39c4386f4..b3019bfc1 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -2,6 +2,8 @@ # API compatibility + + JAX is constantly evolving, and we want to be able to make improvements to its APIs. That said, we want to minimize churn for the JAX user community, and we try to make breaking changes rarely. diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 24980cf30..11752f0b0 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -36,6 +36,8 @@ "source": [ "# Autodidax: JAX core from scratch\n", "\n", + "\n", + "\n", "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n", "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n", "JAX's core system. 
You'll even get clued into our weird jargon!\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index b5f82ec4f..07e997a3d 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -39,6 +39,8 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab. # Autodidax: JAX core from scratch + + Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you're in luck! By reading this tutorial, you'll learn every big idea in JAX's core system. You'll even get clued into our weird jargon! diff --git a/docs/autodidax.py b/docs/autodidax.py index 3a9f1f415..8a4f83fec 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -31,6 +31,8 @@ # # Autodidax: JAX core from scratch # +# +# # Ever want to learn how JAX works, but the implementation seemed impenetrable? # Well, you're in luck! By reading this tutorial, you'll learn every big idea in # JAX's core system. You'll even get clued into our weird jargon! diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index 4a8922dab..cc4a19aab 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -15,6 +15,8 @@ kernelspec: (automatic-differentiation)= # Automatic differentiation + + In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as: diff --git a/docs/automatic-vectorization.md b/docs/automatic-vectorization.md index 794a9f113..7559155e2 100644 --- a/docs/automatic-vectorization.md +++ b/docs/automatic-vectorization.md @@ -15,6 +15,8 @@ kernelspec: (automatic-vectorization)= # Automatic vectorization + + In the previous section we discussed JIT compilation via the {func}`jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`. diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index f6b8e84cd..e0a440491 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -1,5 +1,7 @@ # Building on JAX + + A great way to learn advanced JAX usage is to see how other libraries are using JAX, both how they integrate the library into their API, what functionality it adds mathematically, diff --git a/docs/contributing.md b/docs/contributing.md index 5040fbd9f..2d1331bf2 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,5 +1,7 @@ # Contributing to JAX + + Everyone can contribute to JAX, and we value everyone's contributions. There are several ways to contribute, including: diff --git a/docs/debugging.md b/docs/debugging.md index b53f08139..7ee36f19f 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -15,6 +15,8 @@ kernelspec: (debugging)= # Introduction to debugging + + This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. Let's begin with {func}`jax.debug.print`. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index a804d3603..2dad9b863 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -1,5 +1,7 @@ # The `checkify` transformation + + **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. 
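To anticipate the `checkify` usage summarized in the TL;DR above, here is a minimal sketch (illustrative only; `safe_index` is a made-up example function, not taken from the guide):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_index(x, i):
  # checkify.check behaves like an assert that survives jit
  checkify.check(i >= 0, "index must be non-negative, got {i}", i=i)
  return x[i]

checked = checkify.checkify(safe_index)
err, val = jax.jit(checked)(jnp.arange(3.0), 1)
err.throw()  # raises only if a check failed; a no-op here
```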
Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 90a6cb3bb..1cf1829e5 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -1,5 +1,7 @@ # JAX debugging flags + + JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 35e0f6895..b00fcc13d 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -1,5 +1,7 @@ # Runtime value debugging in JAX + + Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index f29f68c4d..440cc38d9 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,5 +1,7 @@ # `jax.debug.print` and `jax.debug.breakpoint` + + The {mod}`jax.debug` package offers some useful tools for inspecting values inside of JIT-ted functions. diff --git a/docs/deprecation.md b/docs/deprecation.md index 5ee58882a..7a8b867b6 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -1,6 +1,8 @@ (version-support-policy)= # Python and NumPy version support policy + + For NumPy and SciPy version support, JAX follows the Python scientific community's [SPEC 0](https://scientific-python.org/specs/spec-0000/). diff --git a/docs/developer.md b/docs/developer.md index c936d1ba2..018982f4c 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,6 +1,8 @@ (building-from-source)= # Building from source + + First, obtain the JAX source code: ``` diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index a6f27e9e9..e4d871b78 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,5 +1,6 @@ # Device Memory Profiling + ```{note} May 2023 update: we recommend using [Tensorboard diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 5f2ea8e61..70cbd26ba 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -14,6 +14,8 @@ kernelspec: # Distributed data loading in a multi-host/multi-process environment + + This high-level guide demonstrates how you can perform distributed data loading — when you run JAX in a {doc}`multi-host or multi-process environment <./multi_process>`, and the data required for the JAX computations is split across the multiple processes. This document covers the overall approach for how to think about distributed data loading, and then how to apply it to *data-parallel* (simpler) and *model-parallel* (more complicated) workloads. Distributed data loading is usually more efficient (the data is split across processes) but also *more complex* compared with its alternatives, such as: 1) loading the *full global data in a single process*, splitting it up and sending the needed parts to the other processes via RPC; and 2) loading the *full global data in all processes* and only using the needed parts in each process. Loading the full global data is often simpler but more expensive. 
For example, in machine learning the training loop can get blocked while waiting for data, and additional network bandwidth gets used by each process. diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 92d6933e1..40a5d2f0d 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,5 +1,7 @@ # GPU performance tips + + This document focuses on performance tips for neural network workloads. ## Matmul precision diff --git a/docs/installation.md b/docs/installation.md index 8f894ae31..82a0fde31 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,6 +1,8 @@ (installation)= # Installing JAX + + Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib`, which contains compiled binaries and requires different builds for different operating systems and accelerators. diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index fb5293a82..4affae3a6 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -1,6 +1,8 @@ (investigating-a-regression)= # Investigating a regression + + So you updated JAX and you hit a speed regression? You have a little bit of time and are ready to investigate this? Let's first make a JAX issue. diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 860197bec..95d4a632a 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,6 +1,8 @@ (jax-array-migration)= # jax.Array migration + + **yashkatariya@** ## TL;DR diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index ddf99db90..2d442c841 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -22,6 +22,8 @@ kernelspec: (jit-compilation)= # Just-in-time compilation + + In this section, we will further explore how JAX works, and how we can make it performant. We will discuss the {func}`jax.jit` transformation, which will perform *Just In Time* (JIT) compilation of a JAX Python function so it can be executed efficiently in XLA. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index 90a491b6a..4b114c857 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -15,6 +15,8 @@ kernelspec: (key-concepts)= # Key Concepts + + This section briefly introduces some key concepts of the JAX package. 
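Returning to the jit-compilation excerpt above, a minimal sketch of tracing and compiling a function with XLA (the example function is arbitrary and not drawn from the patched docs):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
  return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)            # traced and compiled on the first call
x = jnp.arange(1_000_000.0)
selu_jit(x).block_until_ready()     # later calls reuse the compiled program
```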
(key-concepts-jax-arrays)= diff --git a/docs/multi_process.md b/docs/multi_process.md index 63e4709ed..7d7083bde 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,5 +1,7 @@ # Using JAX in multi-host and multi-process environments + + ## Introduction This guide explains how to use JAX in environments such as diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index b52dc2176..2665e25fd 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -8,6 +8,8 @@ "source": [ "# 🔪 JAX - The Sharp Bits 🔪\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index a46a07b5f..58fcb4310 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -16,6 +16,8 @@ kernelspec: # 🔪 JAX - The Sharp Bits 🔪 + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 70f499a89..3abb6d9cb 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -8,6 +8,8 @@ "source": [ "# Custom derivative rules for JAX-transformable Python functions\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 02b9e0827..ad577d55c 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -15,6 +15,8 @@ kernelspec: # Custom derivative rules for JAX-transformable Python functions + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) *mattjj@ Mar 19 2020, last updated Oct 14 2020* diff --git 
a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 54eb00d78..2face1d4a 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -6,7 +6,9 @@ "id": "PxHrg4Cjuapm" }, "source": [ - "# Distributed arrays and automatic parallelization" + "# Distributed arrays and automatic parallelization\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index dad695d12..b9ec9dc69 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -15,6 +15,8 @@ kernelspec: # Distributed arrays and automatic parallelization + + +++ {"id": "pFtQjv4SzHRj"} [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index 89e774d84..f42e3f74b 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -8,6 +8,8 @@ "source": [ "# How JAX primitives work\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "\n", "*necula@google.com*, October 2019.\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index 17a7379dd..0ebf202f2 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -15,6 +15,8 @@ kernelspec: # How JAX primitives work + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) *necula@google.com*, October 2019. 
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index b48ac353c..f0c157655 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -8,6 +8,8 @@ "source": [ "# Training a Simple Neural Network, with PyTorch Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 8fb2d4f06..2c53bb1e4 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -16,6 +16,8 @@ kernelspec: # Training a Simple Neural Network, with PyTorch Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 1a1a77eb9..7e65aefe3 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# Writing custom Jaxpr interpreters in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 866eeffe1..e52c6a5f8 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -16,6 +16,8 @@ kernelspec: # Writing custom Jaxpr interpreters in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index d0b8fe0c0..edfd0d453 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -8,6 +8,8 @@ "source": [ "# The Autodiff Cookbook\n", "\n", + "\n", + "\n", "[![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "*alexbw@, mattjj@* \n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 0d08e5061..c24d05c0e 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -16,6 +16,8 @@ kernelspec: # The Autodiff Cookbook + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) *alexbw@, mattjj@* diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 9aec8b1a2..f0552e526 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -6,7 +6,9 @@ "id": "29WqUVkCXjDD" }, "source": [ - "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)" + "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 3b8d6218a..b31e093b6 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -15,6 +15,8 @@ kernelspec: ## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`) + + ```{code-cell} import jax import jax.numpy as jnp diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index c4ef1961b..0a8233530 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -8,6 +8,8 @@ "source": [ "# Generalized Convolutions in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 4216e0ffc..3de8f261a 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -16,6 +16,8 @@ kernelspec: # Generalized Convolutions in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 5cda80620..bdf71004c 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,9 @@ "id": "7XNMxdTwURqI" }, "source": 
[ - "# External Callbacks in JAX" + "# External Callbacks in JAX\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 582b3536e..857eef42e 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -15,6 +15,8 @@ kernelspec: # External Callbacks in JAX + + +++ {"id": "h6lXo6bSUYGq"} This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 8368fc3aa..95c00bf1e 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -38,6 +38,8 @@ "source": [ "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index e16c5ce25..8f795484d 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -36,6 +36,8 @@ limitations under the License. # Training a Simple Neural Network, with tensorflow/datasets Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index e9fafa2c1..ed0a13d87 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -7,6 +7,8 @@ "source": [ "# SPMD multi-device parallelism with `shard_map`\n", "\n", + "\n", + "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. 
The two approaches can be mixed, matched, and composed as needed.\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 403463812..67494cfd4 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -16,6 +16,8 @@ kernelspec: # SPMD multi-device parallelism with `shard_map` + + `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. `shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index bcaa1f42b..1c1c9729b 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# How to Think in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 1f25bdc4e..14089fa36 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -15,6 +15,8 @@ kernelspec: # How to Think in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. 
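To make the `shard_map` description above concrete, here is a minimal sketch with an explicit collective (an editorial illustration, assuming the array's leading dimension divides evenly across the available devices):

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def block_sum(x_block):
  # Each instance sees only its local shard; psum is the explicit collective.
  return jax.lax.psum(jnp.sum(x_block), 'i')

x = jnp.arange(8.0 * jax.device_count())
print(block_sum(x))  # equals jnp.sum(x), computed shard by shard
```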
diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 833c3a40d..96b334296 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -8,6 +8,8 @@ "source": [ "# Autobatching for Bayesian Inference\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index ac0864fb1..ea8b4fce2 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -16,6 +16,8 @@ kernelspec: # Autobatching for Bayesian Inference + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. diff --git a/docs/notebooks/xmap_tutorial.ipynb b/docs/notebooks/xmap_tutorial.ipynb index 4e216e588..a8eb76c35 100644 --- a/docs/notebooks/xmap_tutorial.ipynb +++ b/docs/notebooks/xmap_tutorial.ipynb @@ -8,6 +8,8 @@ "source": [ "# Named axes and easy-to-revise parallelism with `xmap`\n", "\n", + "\n", + "\n", "**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n", "\n", "This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.\n", diff --git a/docs/notebooks/xmap_tutorial.md b/docs/notebooks/xmap_tutorial.md index d45570788..c4b511dbe 100644 --- a/docs/notebooks/xmap_tutorial.md +++ b/docs/notebooks/xmap_tutorial.md @@ -15,6 +15,8 @@ kernelspec: # Named axes and easy-to-revise parallelism with `xmap` + + **_UPDATE:_** `xmap` is deprecated and will be removed in a future release. 
The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer. diff --git a/docs/pallas/design.md b/docs/pallas/design.md index 4fae9a11d..991c2c1b3 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design.md @@ -1,5 +1,7 @@ # Pallas Design + + In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas's specific APIs might have changed since. ## Introduction diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 94673e753..19a3c4ef5 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -7,6 +7,8 @@ "source": [ "# Pallas Quickstart\n", "\n", + "\n", + "\n", "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.\n", "\n", "Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 931d1e96f..2f4b3cc77 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -14,6 +14,8 @@ kernelspec: # Pallas Quickstart + + Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction. Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic. diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 87fa02ec1..c954ea3fc 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -6,7 +6,9 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining and `BlockSpec`s" + "# Pipelining and `BlockSpec`s\n", + "\n", + "" ] }, { diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 6acae60cf..13a18a356 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -15,6 +15,8 @@ kernelspec: # Pipelining and `BlockSpec`s + + +++ {"id": "gAJDZh1gBh-h"} In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute. 
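As a concrete taste of the Pallas material above, a minimal kernel sketch close to the quickstart's running example (it assumes a GPU or TPU backend for the Triton/Mosaic lowering; shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Refs point at memory; the kernel reads, computes, and writes explicitly.
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x, y):
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

print(add_vectors(jnp.arange(8.0), jnp.arange(8.0)))
```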
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 7fd0e81a9..2f748825a 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,5 +1,7 @@ # Persistent Compilation Cache + + JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly. diff --git a/docs/profiling.md b/docs/profiling.md index fe92b1b0e..6eceec8f5 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,5 +1,7 @@ # Profiling JAX programs + + ## Viewing program traces with Perfetto We can use the JAX profiler to generate traces of a JAX program that can be diff --git a/docs/pytrees.md b/docs/pytrees.md index 80860b1b8..a39c36db5 100644 --- a/docs/pytrees.md +++ b/docs/pytrees.md @@ -16,6 +16,8 @@ language_info: # Pytrees + + ## What is a pytree? In JAX, we use the term *pytree* to refer to a tree-like structure built out of diff --git a/docs/quickstart.md b/docs/quickstart.md index 6858a2c4e..91ac5a63b 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -14,6 +14,8 @@ kernelspec: # Quickstart + + **JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. This document provides a quick overview of essential JAX features, so you can get started with JAX quickly: diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 4a88ed5cc..85bb5ce01 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -15,6 +15,8 @@ kernelspec: (pseudorandom-numbers)= # Pseudorandom numbers + + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index e6c16e2de..8fa210779 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -7,6 +7,8 @@ "(sharded-computation)=\n", "# Introduction to sharded computation\n", "\n", + "\n", + "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", "The tutorial covers three modes of parallel computation:\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 6a7dd36c2..345ca7987 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -14,6 +14,8 @@ kernelspec: (sharded-computation)= # Introduction to sharded computation + + This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. 
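Stepping back to the pseudorandom-numbers excerpt above, which describes threading an explicit key instead of hidden global state, here is a minimal sketch of that pattern (illustrative only):

```python
import jax

key = jax.random.PRNGKey(42)            # the explicit seed/state
key, subkey = jax.random.split(key)     # never reuse a key: split it instead
x = jax.random.normal(subkey, (3,))     # the same subkey always yields the same draw
```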
SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. The tutorial covers three modes of parallel computation: diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index b802be0e0..5a8af2b74 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -14,6 +14,8 @@ kernelspec: # Stateful Computations + + JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have no side effects such as updating of global state. diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md index 6521a9b85..2bd1cc08e 100644 --- a/docs/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -22,6 +22,8 @@ kernelspec: (working-with-pytrees)= # Working with pytrees + + JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns.
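Following the pytree description that closes this patch, a minimal sketch of treating a nested container of arrays as a single object (the parameter names are made up for illustration):

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "layer2": [jnp.ones(3), 1.0],
}

# tree_map applies a function to every leaf while preserving the structure.
scaled = jax.tree_util.tree_map(lambda leaf: leaf * 2, params)
leaves, treedef = jax.tree_util.tree_flatten(params)
print(treedef, len(leaves))
```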