docs: sentence case page titles, section headings, some content

This commit is contained in:
Roy Frostig 2024-08-12 13:52:27 -07:00
parent 2644299f7e
commit b8f8b7b07f
18 changed files with 52 additions and 52 deletions

View File

@ -11,7 +11,7 @@ and how it's used for computational speedup in other libraries.
Below are examples of how JAX's features can be used to define accelerated
computation across numerous domains and software packages.
## Gradient Computation
## Gradient computation
Easy gradient calculation is a key feature of JAX.
In the [JaxOpt library](https://github.com/google/jaxopt), value and grad are used directly for users in multiple optimization algorithms in [its source code](https://github.com/google/jaxopt/blob/main/jaxopt/_src/base.py#LL87C30-L87C44).
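A minimal sketch of that pattern, using only core JAX (the loss function and data below are invented for illustration, not JaxOpt's actual code):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared-error loss for a toy linear model.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 3))
y = jnp.zeros(8)
w = jnp.zeros(3)

# value_and_grad returns the loss and its gradient in a single call,
# the building block that optimization libraries typically wrap.
value, grad = jax.value_and_grad(loss)(w, x, y)
print(value, grad)
```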
@ -19,7 +19,7 @@ Similarly, the same Dynamax and Optax pairing mentioned above is an example of
gradients enabling estimation methods that were historically challenging:
[Maximum Likelihood Expectation using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html).
## Computational Speedup on a Single Core across Multiple Devices
## Computational speedup on a single core across multiple devices
Models defined in JAX can then be compiled to enable single-computation speedup through JIT compilation.
The same compiled code can then be sent to a CPU device,
or to a GPU or TPU device for additional speedup,
@ -28,7 +28,7 @@ This allows for a smooth workflow from development into production.
In Dynamax the computationally expensive portion of a Linear State Space Model solver has been [jitted](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/models.py#L579).
A more complex example comes from PyTensor which compiles a JAX function dynamically and then [jits the constructed function](https://github.com/pymc-devs/pytensor/blob/main/pytensor/link/jax/linker.py#L64).
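A hedged sketch of the JIT pattern described above (the function and input are made up; this is not Dynamax's or PyTensor's code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # A small chain of operations that XLA can fuse and optimize together.
    return jnp.tanh(x) @ x.T

x = jnp.ones((256, 256))
# The first call compiles for the current default backend (CPU, GPU, or TPU);
# subsequent calls with the same shapes and dtypes reuse the compiled executable.
print(step(x).shape)
```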
## Single and Multi Computer Speedup Using Parallelization
## Single and multi computer speedup using parallelization
Another benefit of JAX is the simplicity of parallelizing computation using
`pmap` and `vmap` function calls or decorators.
In Dynamax, state space models are parallelized with a [VMAP decorator](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L89)
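A minimal, hypothetical sketch of vectorizing a per-example function with `vmap` (not Dynamax's actual solver):

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Toy per-example computation.
    return jnp.dot(params, x)

params = jnp.arange(3.0)
batch = jnp.ones((10, 3))

# vmap maps `predict` over the leading axis of `batch`
# without writing an explicit Python loop.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(params, batch).shape)  # (10,)
```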
@ -43,7 +43,7 @@ such as Neural Networks or State Space models or others,
or provide specific functionality such as optimization.
Here are more specific examples of each pattern.
### Direct Usage
### Direct usage
JAX can be directly imported and used to build models “from scratch”, as shown across this website,
for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html)
or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html).
@ -51,7 +51,7 @@ This may be the best option if you are unable to find prebuilt code
for your particular challenge, or if you're looking to reduce the number
of dependencies in your codebase.
### Composable Domain Specific Libraries with JAX exposed
### Composable domain specific libraries with JAX exposed
Another common approach is packages that provide prebuilt functionality,
whether that be model definition or computation of some kind.
Combinations of these packages can then be mixed and matched for a full
@ -68,7 +68,7 @@ With Dynamax, parameters can be estimated using
[Maximum Likelihood using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html),
or a full Bayesian posterior can be estimated using [MCMC from Blackjax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_hmc.html).
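To make the mix-and-match pattern concrete, here is a hedged sketch of pairing `jax.grad` with an Optax optimizer (a generic example with a placeholder loss, not Dynamax's estimation code):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Placeholder objective standing in for a negative log-likelihood.
    return jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(4)
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(params)  # approaches 3.0
```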
### JAX Totally Hidden from Users
### JAX totally hidden from users
Other libraries opt to completely wrap JAX in their model-specific API.
An example is PyMC and [Pytensor](https://github.com/pymc-devs/pytensor),
in which a user may never “see” JAX directly

View File

@ -162,7 +162,7 @@ possible. The `git rebase -i` command might be useful to this end.
(linting-and-type-checking)=
### Linting and Type-checking
### Linting and type-checking
JAX uses [mypy](https://mypy.readthedocs.io/) and
[ruff](https://docs.astral.sh/ruff/) to statically test code quality; the
@ -186,7 +186,7 @@ fix the issues you can push new commits to your branch.
### Restricted test suite
Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This
Once your PR has been reviewed, a JAX maintainer will mark it as `pull ready`. This
will trigger a larger set of tests, including tests on GPU and TPU backends that are
not available via standard GitHub CI. Detailed results of these tests are not publicly
viewable, but the JAX maintainer assigned to your PR will communicate with you regarding

View File

@ -1,4 +1,4 @@
# Device Memory Profiling
# Device memory profiling
<!--* freshness: { reviewed: '2024-03-08' } *-->
@ -9,7 +9,7 @@ profile, open the `memory_viewer` tab of the Tensorboard profiler for more
detailed and understandable device memory usage.
```
The JAX Device Memory Profiler allows us to explore how and why JAX programs are
The JAX device memory profiler allows us to explore how and why JAX programs are
using GPU or TPU memory. For example, it can be used to:
* Figure out which arrays and executables are in GPU memory at a given time, or

View File

@ -28,7 +28,7 @@ JAX glossary of terms
able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`).
jaxpr
Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that
Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
See :ref:`understanding-jaxprs` for more discussion and examples.
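For a quick look at a jaxpr, `jax.make_jaxpr` traces a function on example arguments and returns its intermediate representation (a minimal sketch):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# Tracing f produces the jaxpr that is later lowered to XLA.
print(jax.make_jaxpr(f)(1.0))
```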

View File

@ -23,7 +23,7 @@ Here is a suggested investigation strategy:
2. Hourly recompilation while keeping XLA and JAX in sync.
3. Final verification: maybe a manual check of a few commits (or a git bisect).
## Nightly investigation.
## Nightly investigation
This can be done by using [JAX-Toolbox nightly
containers](https://github.com/NVIDIA/JAX-Toolbox).
@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown
earlier; let's ignore it for this example. It would only be another
hourly investigation between those dates.
## Hourly investigation.
## Hourly investigation
This checks out JAX and XLA at each hour between the two dates,
rebuilds everything, and runs the test. The scripts are structured

View File

@ -164,7 +164,7 @@ before (with two input vars, one for each element of the input tuple)
Constant Vars
Constant vars
-------------
Some values in jaxprs are constants, in that their value does not depend on the

View File

@ -6,7 +6,7 @@
"id": "18AF5Ab4p6VL"
},
"source": [
"# Training a Simple Neural Network, with PyTorch Data Loading\n",
"# Training a simple neural network, with PyTorch data loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
@ -261,7 +261,7 @@
"id": "umJJGZCC2oKl"
},
"source": [
"## Data Loading with PyTorch\n",
"## Data loading with PyTorch\n",
"\n",
"JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays."
]
@ -494,7 +494,7 @@
"id": "xxPd6Qw3Z98v"
},
"source": [
"## Training Loop"
"## Training loop"
]
},
{

View File

@ -14,7 +14,7 @@ kernelspec:
+++ {"id": "18AF5Ab4p6VL"}
# Training a Simple Neural Network, with PyTorch Data Loading
# Training a simple neural network, with PyTorch data loading
<!--* freshness: { reviewed: '2024-05-03' } *-->
@ -175,7 +175,7 @@ def update(params, x, y):
+++ {"id": "umJJGZCC2oKl"}
## Data Loading with PyTorch
## Data loading with PyTorch
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.
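A hedged sketch of what such a shim can look like, an illustrative collate function rather than necessarily the exact code in this tutorial:

```python
import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # Stack samples into NumPy arrays instead of torch.Tensors,
    # so batches can be fed straight into JAX functions.
    if isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    return np.asarray(batch)

# Usage, assuming `dataset` is any PyTorch Dataset:
# loader = DataLoader(dataset, batch_size=128, collate_fn=numpy_collate)
```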
@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
+++ {"id": "xxPd6Qw3Z98v"}
## Training Loop
## Training loop
```{code-cell} ipython3
:id: X2DnZo3iYj18

View File

@ -6,7 +6,7 @@
"id": "TVT_MVvc02AA"
},
"source": [
"# Generalized Convolutions in JAX\n",
"# Generalized convolutions in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
@ -28,7 +28,7 @@
"id": "ewZEn2X12-Ng"
},
"source": [
"## Basic One-dimensional Convolution\n",
"## Basic one-dimensional convolution\n",
"\n",
"Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:"
]
@ -91,7 +91,7 @@
"id": "5ndvLDIH4rv6"
},
"source": [
"## Basic N-dimensional Convolution\n",
"## Basic N-dimensional convolution\n",
"\n",
"For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n",
"\n",
@ -160,7 +160,7 @@
"id": "bxuUjFVG-v1h"
},
"source": [
"## General Convolutions"
"## General convolutions"
]
},
{

View File

@ -14,7 +14,7 @@ kernelspec:
+++ {"id": "TVT_MVvc02AA"}
# Generalized Convolutions in JAX
# Generalized convolutions in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are
+++ {"id": "ewZEn2X12-Ng"}
## Basic One-dimensional Convolution
## Basic one-dimensional convolution
Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:
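The smoothing example looks roughly like this (a sketch with invented noisy data; the rendered tutorial's code cell may differ in detail):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.linspace(0, 10, 500)
noisy = jnp.sin(x) + 0.2 * jax.random.normal(key, x.shape)

# Smooth with a short box filter; mode='same' keeps the output
# the same length as the input.
window = jnp.ones(10) / 10
smoothed = jnp.convolve(noisy, window, mode='same')
print(smoothed.shape)  # (500,)
```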
@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d
+++ {"id": "5ndvLDIH4rv6"}
## Basic N-dimensional Convolution
## Basic N-dimensional convolution
For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.
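For instance, a simple 2D blur can be written as follows (a hedged sketch; the image is just random data):

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve

image = jax.random.normal(jax.random.PRNGKey(0), (128, 128))
kernel = jnp.ones((3, 3)) / 9.0

# convolve generalizes jnp.convolve to N dimensions;
# mode='same' preserves the input shape.
blurred = convolve(image, kernel, mode='same')
print(blurred.shape)  # (128, 128)
```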
@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l
+++ {"id": "bxuUjFVG-v1h"}
## General Convolutions
## General convolutions
+++ {"id": "0pcn2LeS-03b"}

View File

@ -6,7 +6,7 @@
"id": "7XNMxdTwURqI"
},
"source": [
"# External Callbacks in JAX\n",
"# External callbacks in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]

View File

@ -13,7 +13,7 @@ kernelspec:
+++ {"id": "7XNMxdTwURqI"}
# External Callbacks in JAX
# External callbacks in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->

View File

@ -36,7 +36,7 @@
"id": "B_XlLLpcWjkA"
},
"source": [
"# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n",
"# Training a simple neural network, with tensorflow/datasets data loading\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
@ -274,7 +274,7 @@
"id": "umJJGZCC2oKl"
},
"source": [
"## Data Loading with `tensorflow/datasets`\n",
"## Data loading with `tensorflow/datasets`\n",
"\n",
"JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader."
]
@ -344,7 +344,7 @@
"id": "xxPd6Qw3Z98v"
},
"source": [
"## Training Loop"
"## Training loop"
]
},
{

View File

@ -34,7 +34,7 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"}
# Training a Simple Neural Network, with tensorflow/datasets Data Loading
# Training a simple neural network, with tensorflow/datasets data loading
<!--* freshness: { reviewed: '2024-05-03' } *-->
@ -183,7 +183,7 @@ def update(params, x, y):
+++ {"id": "umJJGZCC2oKl"}
## Data Loading with `tensorflow/datasets`
## Data loading with `tensorflow/datasets`
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader.
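A hedged sketch of that pattern, assuming `tensorflow_datasets` is installed (details such as the split and batch handling may differ from the tutorial's actual code):

```python
import tensorflow_datasets as tfds

# Load the full MNIST train split as a single batch of NumPy arrays,
# which can then be fed directly to JAX functions.
mnist = tfds.as_numpy(tfds.load('mnist', split='train', batch_size=-1))
images, labels = mnist['image'], mnist['label']
print(images.shape, labels.shape)
```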
@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape)
+++ {"id": "xxPd6Qw3Z98v"}
## Training Loop
## Training loop
```{code-cell} ipython3
:id: X2DnZo3iYj18

View File

@ -6,7 +6,7 @@
"id": "LQHmwePqryRU"
},
"source": [
"# How to Think in JAX\n",
"# How to think in JAX\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
@ -23,7 +23,7 @@
"source": [
"## JAX vs. NumPy\n",
"\n",
"**Key Concepts:**\n",
"**Key concepts:**\n",
"\n",
"- JAX provides a NumPy-inspired interface for convenience.\n",
"- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n",
@ -282,7 +282,7 @@
"source": [
"## NumPy, lax & XLA: JAX API layering\n",
"\n",
"**Key Concepts:**\n",
"**Key concepts:**\n",
"\n",
"- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n",
"- `jax.lax` is a lower-level API that is stricter and often more powerful.\n",
@ -475,7 +475,7 @@
"source": [
"## To JIT or not to JIT\n",
"\n",
"**Key Concepts:**\n",
"**Key concepts:**\n",
"\n",
"- By default JAX executes operations one at a time, in sequence.\n",
"- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n",
@ -675,7 +675,7 @@
"source": [
"## JIT mechanics: tracing and static variables\n",
"\n",
"**Key Concepts:**\n",
"**Key concepts:**\n",
"\n",
"- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n",
"\n",
@ -932,9 +932,9 @@
"id": "r-RCl_wD5lI7"
},
"source": [
"## Static vs Traced Operations\n",
"## Static vs traced operations\n",
"\n",
"**Key Concepts:**\n",
"**Key concepts:**\n",
"\n",
"- Just as values can be either static or traced, operations can be static or traced.\n",
"\n",

View File

@ -13,7 +13,7 @@ kernelspec:
+++ {"id": "LQHmwePqryRU"}
# How to Think in JAX
# How to think in JAX
<!--* freshness: { reviewed: '2024-04-08' } *-->
@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b
## JAX vs. NumPy
**Key Concepts:**
**Key concepts:**
- JAX provides a NumPy-inspired interface for convenience.
- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
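A minimal sketch of the drop-in behavior mentioned above:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)

# The same code runs on either array type thanks to duck-typing.
print(np.sum(np.sin(x_np)))
print(jnp.sum(jnp.sin(x_jnp)))
```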
@ -132,7 +132,7 @@ print(y)
## NumPy, lax & XLA: JAX API layering
**Key Concepts:**
**Key concepts:**
- `jax.numpy` is a high-level wrapper that provides a familiar interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.
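A short sketch of the strictness difference (the exact error type and message may vary across JAX versions):

```python
import jax.numpy as jnp
from jax import lax

# jax.numpy follows NumPy-style implicit type promotion.
print(jnp.add(1, 1.0))   # promotes to float

# jax.lax requires dtypes to match exactly.
try:
    lax.add(1, 1.0)
except Exception as err:
    print("lax.add rejected mixed dtypes:", type(err).__name__)
```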
@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op
## To JIT or not to JIT
**Key Concepts:**
**Key concepts:**
- By default JAX executes operations one at a time, in sequence.
- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp
## JIT mechanics: tracing and static variables
**Key Concepts:**
**Key concepts:**
- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.
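A hedged sketch of what tracing looks like in practice: printing inside a jitted function during its first call shows tracer objects, not concrete values:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # During tracing, x is an abstract tracer that carries shape and dtype
    # but no concrete data; this print runs once, at trace time.
    print("traced value:", x)
    return jnp.sum(x ** 2)

f(jnp.arange(4.0))  # prints something like Traced<ShapedArray(float32[4])>
```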
@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace
+++ {"id": "r-RCl_wD5lI7"}
## Static vs Traced Operations
## Static vs traced operations
**Key Concepts:**
**Key concepts:**
- Just as values can be either static or traced, operations can be static or traced.
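A short sketch of the distinction (assuming a recent JAX version): inside a jitted function, `x.shape` is a static Python tuple, while operations on `x` itself are traced:

```python
import jax
import numpy as np
import jax.numpy as jnp

@jax.jit
def flatten(x):
    n = np.prod(x.shape)         # static: evaluated with NumPy at trace time
    return jnp.reshape(x, (n,))  # traced: the shape argument must be static

print(flatten(jnp.ones((3, 4))).shape)  # (12,)
```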

View File

@ -1,4 +1,4 @@
# Persistent Compilation Cache
# Persistent compilation cache
<!--* freshness: { reviewed: '2024-04-09' } *-->

View File

@ -12,7 +12,7 @@ kernelspec:
name: python3
---
# Stateful Computations
# Stateful computations
<!--* freshness: { reviewed: '2024-05-03' } *-->