docs: sentence case page titles, section headings, some content

commit b8f8b7b07f (parent 2644299f7e)
@@ -11,7 +11,7 @@ and how it's used for computational speedup in other libraries.
 Below are examples of how JAX's features can be used to define accelerated
 computation across numerous domains and software packages.

-## Gradient Computation
+## Gradient computation
 Easy gradient calculation is a key feature of JAX.
 In the [JaxOpt library](https://github.com/google/jaxopt) value and grad is directly utilized for users in multiple optimization algorithms in [its source code](https://github.com/google/jaxopt/blob/main/jaxopt/_src/base.py#LL87C30-L87C44).
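For context, the "value and grad" referenced above corresponds to `jax.value_and_grad`. A minimal sketch of that pattern (the `loss` function and data here are made up for illustration, not taken from JaxOpt):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Squared-error loss for a linear model; purely illustrative.
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 3))
y = jnp.zeros(8)
params = jnp.array([0.5, -0.2, 0.1])

# value_and_grad returns the loss value and its gradient in one pass,
# the building block that optimization libraries rely on.
loss_val, grads = jax.value_and_grad(loss)(params, x, y)
print(loss_val, grads)
```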
@@ -19,7 +19,7 @@ Similarly the same Dynamax Optax pairing mentioned above is an example of
 gradients enabling estimation methods that were challenging historically
 [Maximum Likelihood Expectation using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html).

-## Computational Speedup on a Single Core across Multiple Devices
+## Computational speedup on a single core across multiple devices
 Models defined in JAX can then be compiled to enable single computation speedup through JIT compiling.
 The same compiled code can then be sent to a CPU device,
 to a GPU or TPU device for additional speedup,
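A minimal sketch of that JIT workflow (illustrative only; the `predict` function is invented, and the explicit device placement assumes at least a CPU backend is available):

```python
import jax
import jax.numpy as jnp

@jax.jit
def predict(w, x):
    # A small sequence of ops that XLA compiles and fuses into one computation.
    return jnp.tanh(x @ w)

w = jnp.ones((4, 4))
x = jnp.ones((2, 4))

# The same jitted function runs on whichever backend is available:
# CPU here, or a GPU/TPU if JAX was installed with that support.
cpu = jax.devices("cpu")[0]
out = predict(jax.device_put(w, cpu), jax.device_put(x, cpu))
print(out.shape)
```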
@@ -28,7 +28,7 @@ This allows for a smooth workflow from development into production.
 In Dynamax the computationally expensive portion of a Linear State Space Model solver has been [jitted](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/models.py#L579).
 A more complex example comes from PyTensor which compiles a JAX function dynamically and then [jits the constructed function](https://github.com/pymc-devs/pytensor/blob/main/pytensor/link/jax/linker.py#L64).

-## Single and Multi Computer Speedup Using Parallelization
+## Single and multi computer speedup using parallelization
 Another benefit of JAX is the simplicity of parallelizing computation using
 `pmap` and `vmap` function calls or decorators.
 In Dynamax state space models are parallelized with a [VMAP decorator](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L89)
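For readers unfamiliar with these transforms, a small `vmap` sketch (the `dot` kernel is invented for illustration; `pmap` follows the same pattern but maps over multiple devices):

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # A kernel written for single examples.
    return jnp.dot(a, b)

a = jnp.ones((10, 3))
b = jnp.ones((10, 3))

# vmap vectorizes dot over the leading batch axis without rewriting it.
batched_dot = jax.vmap(dot)
print(batched_dot(a, b).shape)  # (10,)
```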
@@ -43,7 +43,7 @@ such as Neural Networks or State Space models or others,
 or provide specific functionality such as optimization.
 Here are more specific examples of each pattern.

-### Direct Usage
+### Direct usage
 Jax can be directly imported and utilized to build models “from scratch” as shown across this website,
 for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html)
 or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html).

@@ -51,7 +51,7 @@ This may be the best option if you are unable to find prebuilt code
 for your particular challenge, or if you're looking to reduce the number
 of dependencies in your codebase.

-### Composable Domain Specific Libraries with JAX exposed
+### Composable domain specific libraries with JAX exposed
 Another common approach are packages that provide prebuilt functionality,
 whether it be model definition, or computation of some type.
 Combinations of these packages can then be mixed and matched for a full

@@ -68,7 +68,7 @@ With Dynamax parameters can be estimated using
 [Maximum Likelihood using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html)
 or full Bayesian Posterior can be estimating using [MCMC from Blackjax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_hmc.html)

-### JAX Totally Hidden from Users
+### JAX totally hidden from users
 Other libraries opt to completely wrap JAX in their model specific API.
 An example is PyMC and [Pytensor](https://github.com/pymc-devs/pytensor),
 in which a user may never “see” JAX directly

@@ -162,7 +162,7 @@ possible. The `git rebase -i` command might be useful to this end.

 (linting-and-type-checking)=

-### Linting and Type-checking
+### Linting and type-checking

 JAX uses [mypy](https://mypy.readthedocs.io/) and
 [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the

@@ -186,7 +186,7 @@ fix the issues you can push new commits to your branch.

 ### Restricted test suite

-Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This
+Once your PR has been reviewed, a JAX maintainer will mark it as `pull ready`. This
 will trigger a larger set of tests, including tests on GPU and TPU backends that are
 not available via standard GitHub CI. Detailed results of these tests are not publicly
 viewable, but the JAX maintainer assigned to your PR will communicate with you regarding

@@ -1,4 +1,4 @@
-# Device Memory Profiling
+# Device memory profiling

 <!--* freshness: { reviewed: '2024-03-08' } *-->

@@ -9,7 +9,7 @@ profile, open the `memory_viewer` tab of the Tensorboard profiler for more
 detailed and understandable device memory usage.
 ```

-The JAX Device Memory Profiler allows us to explore how and why JAX programs are
+The JAX device memory profiler allows us to explore how and why JAX programs are
 using GPU or TPU memory. For example, it can be used to:

 * Figure out which arrays and executables are in GPU memory at a given time, or
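Capturing such a profile looks roughly like the following (a sketch; the array sizes and the output filename are arbitrary choices):

```python
import jax
import jax.numpy as jnp
import jax.profiler

# Allocate some arrays on the default device so there is something to inspect.
x = jnp.ones((1000, 1000))
y = x @ x.T

# Write a pprof-format snapshot of current device memory usage.
jax.profiler.save_device_memory_profile("memory.prof")
```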
@@ -28,7 +28,7 @@ JAX glossary of terms
 able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`).

 jaxpr
-Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that
+Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
 is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
 See :ref:`understanding-jaxprs` for more discussion and examples.
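A jaxpr can be inspected directly with `jax.make_jaxpr`; a minimal illustration (the function `f` is invented for this sketch):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Tracing f produces its intermediate representation (a jaxpr),
# which is what gets handed to XLA for compilation.
print(jax.make_jaxpr(f)(jnp.arange(3.0)))
```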
@@ -23,7 +23,7 @@ Here is a suggested investigation strategy:
 2. Hourly recompilation while keeping XLA and JAX in sync.
 3. Final verification: maybe a manual check of a few commits (or a git bisect).

-## Nightly investigation.
+## Nightly investigation

 This can be done by using [JAX-Toolbox nightly
 containers](https://github.com/NVIDIA/JAX-Toolbox).

@@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown
 earlier, lets ignore it for this example. It would be only another
 hourly investigation between those dates.

-## Hourly investigation.
+## Hourly investigation

 This does a checkout of JAX and XLA at each hour between the 2 dates,
 rebuilds everything and runs the test. The scripts are structured

@@ -164,7 +164,7 @@ before (with two input vars, one for each element of the input tuple)


-Constant Vars
+Constant vars
 -------------

 Some values in jaxprs are constants, in that their value does not depend on the

@@ -6,7 +6,7 @@
 "id": "18AF5Ab4p6VL"
 },
 "source": [
-"# Training a Simple Neural Network, with PyTorch Data Loading\n",
+"# Training a simple neural network, with PyTorch data loading\n",
 "\n",
 "<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
 "\n",

@@ -261,7 +261,7 @@
 "id": "umJJGZCC2oKl"
 },
 "source": [
-"## Data Loading with PyTorch\n",
+"## Data loading with PyTorch\n",
 "\n",
 "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays."
 ]

@@ -494,7 +494,7 @@
 "id": "xxPd6Qw3Z98v"
 },
 "source": [
-"## Training Loop"
+"## Training loop"
 ]
 },
 {

@@ -14,7 +14,7 @@ kernelspec:

 +++ {"id": "18AF5Ab4p6VL"}

-# Training a Simple Neural Network, with PyTorch Data Loading
+# Training a simple neural network, with PyTorch data loading

 <!--* freshness: { reviewed: '2024-05-03' } *-->

@@ -175,7 +175,7 @@ def update(params, x, y):

 +++ {"id": "umJJGZCC2oKl"}

-## Data Loading with PyTorch
+## Data loading with PyTorch

 JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.

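The "tiny shim" mentioned here is essentially a collate function that hands back NumPy arrays instead of torch tensors. A rough sketch of the idea (assumes `torch` is installed; the tutorial's actual helper may differ in detail):

```python
import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # Stack a list of (image, label) pairs into NumPy arrays so that
    # downstream JAX code never sees torch tensors.
    images = np.stack([np.asarray(x) for x, _ in batch])
    labels = np.asarray([y for _, y in batch])
    return images, labels

# Usage sketch: DataLoader(my_dataset, batch_size=128, collate_fn=numpy_collate)
```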
@@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

 +++ {"id": "xxPd6Qw3Z98v"}

-## Training Loop
+## Training loop

 ```{code-cell} ipython3
 :id: X2DnZo3iYj18

@@ -6,7 +6,7 @@
 "id": "TVT_MVvc02AA"
 },
 "source": [
-"# Generalized Convolutions in JAX\n",
+"# Generalized convolutions in JAX\n",
 "\n",
 "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
 "\n",

@@ -28,7 +28,7 @@
 "id": "ewZEn2X12-Ng"
 },
 "source": [
-"## Basic One-dimensional Convolution\n",
+"## Basic one-dimensional convolution\n",
 "\n",
 "Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:"
 ]

@@ -91,7 +91,7 @@
 "id": "5ndvLDIH4rv6"
 },
 "source": [
-"## Basic N-dimensional Convolution\n",
+"## Basic N-dimensional convolution\n",
 "\n",
 "For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n",
 "\n",

@@ -160,7 +160,7 @@
 "id": "bxuUjFVG-v1h"
 },
 "source": [
-"## General Convolutions"
+"## General convolutions"
 ]
 },
 {

@@ -14,7 +14,7 @@ kernelspec:

 +++ {"id": "TVT_MVvc02AA"}

-# Generalized Convolutions in JAX
+# Generalized convolutions in JAX

 <!--* freshness: { reviewed: '2024-04-08' } *-->

@@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are

 +++ {"id": "ewZEn2X12-Ng"}

-## Basic One-dimensional Convolution
+## Basic one-dimensional convolution

 Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:

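The smoothing example that follows this text in the source page is not shown in this diff; an approximate stand-in (noise level and window size chosen arbitrarily) looks like:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
signal = jnp.linspace(0, 10, 100) + random.normal(key, (100,))

# A box filter: each output point is the mean of a 5-sample window.
window = jnp.ones(5) / 5
smoothed = jnp.convolve(signal, window, mode='same')
print(smoothed.shape)  # (100,)
```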
@@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d

 +++ {"id": "5ndvLDIH4rv6"}

-## Basic N-dimensional Convolution
+## Basic N-dimensional convolution

 For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.

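A two-dimensional sketch of the same idea (the image and kernel here are arbitrary, chosen only to show the interface):

```python
import jax.numpy as jnp
from jax.scipy import signal

image = jnp.zeros((32, 32)).at[16, 16].set(1.0)   # a single bright pixel
kernel = jnp.ones((3, 3)) / 9.0                   # 3x3 mean filter

# mode='same' keeps the output the same shape as the input image.
blurred = signal.convolve(image, kernel, mode='same')
print(blurred.shape)  # (32, 32)
```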
@@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l

 +++ {"id": "bxuUjFVG-v1h"}

-## General Convolutions
+## General convolutions

 +++ {"id": "0pcn2LeS-03b"}

@@ -6,7 +6,7 @@
 "id": "7XNMxdTwURqI"
 },
 "source": [
-"# External Callbacks in JAX\n",
+"# External callbacks in JAX\n",
 "\n",
 "<!--* freshness: { reviewed: '2024-04-08' } *-->"
 ]

@@ -13,7 +13,7 @@ kernelspec:

 +++ {"id": "7XNMxdTwURqI"}

-# External Callbacks in JAX
+# External callbacks in JAX

 <!--* freshness: { reviewed: '2024-04-08' } *-->

@@ -36,7 +36,7 @@
 "id": "B_XlLLpcWjkA"
 },
 "source": [
-"# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n",
+"# Training a simple neural network, with tensorflow/datasets data loading\n",
 "\n",
 "<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
 "\n",

@@ -274,7 +274,7 @@
 "id": "umJJGZCC2oKl"
 },
 "source": [
-"## Data Loading with `tensorflow/datasets`\n",
+"## Data loading with `tensorflow/datasets`\n",
 "\n",
 "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader."
 ]

@@ -344,7 +344,7 @@
 "id": "xxPd6Qw3Z98v"
 },
 "source": [
-"## Training Loop"
+"## Training loop"
 ]
 },
 {

@@ -34,7 +34,7 @@ limitations under the License.

 +++ {"id": "B_XlLLpcWjkA"}

-# Training a Simple Neural Network, with tensorflow/datasets Data Loading
+# Training a simple neural network, with tensorflow/datasets data loading

 <!--* freshness: { reviewed: '2024-05-03' } *-->

@@ -183,7 +183,7 @@ def update(params, x, y):

 +++ {"id": "umJJGZCC2oKl"}

-## Data Loading with `tensorflow/datasets`
+## Data loading with `tensorflow/datasets`

 JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader.

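Loading the data with `tensorflow/datasets` looks roughly like this (a sketch; assumes the `tensorflow-datasets` package is installed and that MNIST is downloaded on first use — the tutorial's own code may differ):

```python
import numpy as np
import tensorflow_datasets as tfds

# batch_size=-1 loads the full split into memory as a single batch.
mnist = tfds.load(name="mnist", batch_size=-1)
train = tfds.as_numpy(mnist["train"])

images = np.reshape(train["image"], (len(train["image"]), -1)) / np.float32(255.0)
labels = train["label"]
print(images.shape, labels.shape)
```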
@@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape)

 +++ {"id": "xxPd6Qw3Z98v"}

-## Training Loop
+## Training loop

 ```{code-cell} ipython3
 :id: X2DnZo3iYj18

@@ -6,7 +6,7 @@
 "id": "LQHmwePqryRU"
 },
 "source": [
-"# How to Think in JAX\n",
+"# How to think in JAX\n",
 "\n",
 "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
 "\n",

@@ -23,7 +23,7 @@
 "source": [
 "## JAX vs. NumPy\n",
 "\n",
-"**Key Concepts:**\n",
+"**Key concepts:**\n",
 "\n",
 "- JAX provides a NumPy-inspired interface for convenience.\n",
 "- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n",

@@ -282,7 +282,7 @@
 "source": [
 "## NumPy, lax & XLA: JAX API layering\n",
 "\n",
-"**Key Concepts:**\n",
+"**Key concepts:**\n",
 "\n",
 "- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n",
 "- `jax.lax` is a lower-level API that is stricter and often more powerful.\n",

@@ -475,7 +475,7 @@
 "source": [
 "## To JIT or not to JIT\n",
 "\n",
-"**Key Concepts:**\n",
+"**Key concepts:**\n",
 "\n",
 "- By default JAX executes operations one at a time, in sequence.\n",
 "- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n",

@@ -675,7 +675,7 @@
 "source": [
 "## JIT mechanics: tracing and static variables\n",
 "\n",
-"**Key Concepts:**\n",
+"**Key concepts:**\n",
 "\n",
 "- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n",
 "\n",

@@ -932,9 +932,9 @@
 "id": "r-RCl_wD5lI7"
 },
 "source": [
-"## Static vs Traced Operations\n",
+"## Static vs traced operations\n",
 "\n",
-"**Key Concepts:**\n",
+"**Key concepts:**\n",
 "\n",
 "- Just as values can be either static or traced, operations can be static or traced.\n",
 "\n",

@@ -13,7 +13,7 @@ kernelspec:

 +++ {"id": "LQHmwePqryRU"}

-# How to Think in JAX
+# How to think in JAX

 <!--* freshness: { reviewed: '2024-04-08' } *-->

@@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b

 ## JAX vs. NumPy

-**Key Concepts:**
+**Key concepts:**

 - JAX provides a NumPy-inspired interface for convenience.
 - Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
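A tiny illustration of that duck-typing point (illustrative only):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5.0)           # a JAX array on the default device
print(np.sum(x))              # NumPy functions accept it via duck-typing
print(jnp.sum(np.arange(5)))  # and jax.numpy accepts NumPy arrays
```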
@@ -132,7 +132,7 @@ print(y)

 ## NumPy, lax & XLA: JAX API layering

-**Key Concepts:**
+**Key concepts:**

 - `jax.numpy` is a high-level wrapper that provides a familiar interface.
 - `jax.lax` is a lower-level API that is stricter and often more powerful.
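One concrete way to see the difference in strictness (a sketch; the exact exception type raised by `lax.add` varies across JAX versions):

```python
import jax.numpy as jnp
from jax import lax

print(jnp.add(1, 1.0))        # jax.numpy promotes mixed types implicitly

try:
    lax.add(1, 1.0)           # jax.lax requires matching dtypes
except Exception as err:      # error type differs between JAX versions
    print("lax.add rejects mixed types:", err)

print(lax.add(jnp.float32(1), 1.0))  # explicit promotion works
```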
@@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op

 ## To JIT or not to JIT

-**Key Concepts:**
+**Key concepts:**

 - By default JAX executes operations one at a time, in sequence.
 - Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
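A minimal sketch of the "optimized together and run at once" point (the `norm` function is an arbitrary example; timings vary by machine and backend):

```python
import time
import jax
import jax.numpy as jnp

def norm(x):
    x = x - x.mean(0)
    return x / x.std(0)

norm_jit = jax.jit(norm)
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

norm_jit(x).block_until_ready()    # first call includes compilation

for fn, label in [(norm, "eager"), (norm_jit, "jit")]:
    start = time.perf_counter()
    fn(x).block_until_ready()      # block so we time the actual compute
    print(label, time.perf_counter() - start)
```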
@@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp

 ## JIT mechanics: tracing and static variables

-**Key Concepts:**
+**Key concepts:**

 - JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.

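The tracing bullet can be made concrete with a quick sketch: printing inside a jitted function shows the abstract tracer, not a value.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("traced with:", x)   # runs once, at trace time, with an abstract value
    return x * 2

f(jnp.arange(3))   # prints a Traced<ShapedArray(...)> object on the first call
f(jnp.arange(3))   # no print: the cached compiled version is reused
```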
@@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace

 +++ {"id": "r-RCl_wD5lI7"}

-## Static vs Traced Operations
+## Static vs traced operations

-**Key Concepts:**
+**Key concepts:**

 - Just as values can be either static or traced, operations can be static or traced.

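A small sketch of the static-versus-traced distinction (the `flatten` helper is invented for illustration):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def flatten(x):
    # np.prod runs at trace time on the static shape, so the resulting size is
    # a concrete (static) value that reshape can use inside jit.
    return x.reshape((np.prod(x.shape),))

print(flatten(jnp.ones((3, 4))).shape)  # (12,)

# Computing the size with jnp.prod(jnp.array(x.shape)) instead would make it a
# traced value, and reshaping to a traced size fails under jit.
```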
@@ -1,4 +1,4 @@
-# Persistent Compilation Cache
+# Persistent compilation cache

 <!--* freshness: { reviewed: '2024-04-09' } *-->

@@ -12,7 +12,7 @@ kernelspec:
 name: python3
 ---

-# Stateful Computations
+# Stateful computations

 <!--* freshness: { reviewed: '2024-05-03' } *-->
