diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index e0a440491..9416b16cd 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -11,7 +11,7 @@ and how it's used for computational speedup in other libraries. Below are examples of how JAX's features can be used to define accelerated computation across numerous domains and software packages. -## Gradient Computation +## Gradient computation Easy gradient calculation is a key feature of JAX. In the [JaxOpt library](https://github.com/google/jaxopt) value and grad is directly utilized for users in multiple optimization algorithms in [its source code](https://github.com/google/jaxopt/blob/main/jaxopt/_src/base.py#LL87C30-L87C44). @@ -19,7 +19,7 @@ Similarly the same Dynamax Optax pairing mentioned above is an example of gradients enabling estimation methods that were challenging historically [Maximum Likelihood Expectation using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html). -## Computational Speedup on a Single Core across Multiple Devices +## Computational speedup on a single core across multiple devices Models defined in JAX can then be compiled to enable single computation speedup through JIT compiling. The same compiled code can then be sent to a CPU device, to a GPU or TPU device for additional speedup, @@ -28,7 +28,7 @@ This allows for a smooth workflow from development into production. In Dynamax the computationally expensive portion of a Linear State Space Model solver has been [jitted](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/models.py#L579). A more complex example comes from PyTensor which compiles a JAX function dynamically and then [jits the constructed function](https://github.com/pymc-devs/pytensor/blob/main/pytensor/link/jax/linker.py#L64). -## Single and Multi Computer Speedup Using Parallelization +## Single and multi computer speedup using parallelization Another benefit of JAX is the simplicity of parallelizing computation using `pmap` and `vmap` function calls or decorators. In Dynamax state space models are parallelized with a [VMAP decorator](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L89) @@ -43,7 +43,7 @@ such as Neural Networks or State Space models or others, or provide specific functionality such as optimization. Here are more specific examples of each pattern. -### Direct Usage +### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). @@ -51,7 +51,7 @@ This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. -### Composable Domain Specific Libraries with JAX exposed +### Composable domain specific libraries with JAX exposed Another common approach are packages that provide prebuilt functionality, whether it be model definition, or computation of some type. Combinations of these packages can then be mixed and matched for a full @@ -68,7 +68,7 @@ With Dynamax parameters can be estimated using [Maximum Likelihood using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html) or full Bayesian Posterior can be estimating using [MCMC from Blackjax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_hmc.html) -### JAX Totally Hidden from Users +### JAX totally hidden from users Other libraries opt to completely wrap JAX in their model specific API. An example is PyMC and [Pytensor](https://github.com/pymc-devs/pytensor), in which a user may never “see” JAX directly diff --git a/docs/contributing.md b/docs/contributing.md index 4aecf7153..d7fa6e9da 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -162,7 +162,7 @@ possible. The `git rebase -i` command might be useful to this end. (linting-and-type-checking)= -### Linting and Type-checking +### Linting and type-checking JAX uses [mypy](https://mypy.readthedocs.io/) and [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the @@ -186,7 +186,7 @@ fix the issues you can push new commits to your branch. ### Restricted test suite -Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This +Once your PR has been reviewed, a JAX maintainer will mark it as `pull ready`. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index e4d871b78..a906c54d5 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,4 +1,4 @@ -# Device Memory Profiling +# Device memory profiling @@ -9,7 +9,7 @@ profile, open the `memory_viewer` tab of the Tensorboard profiler for more detailed and understandable device memory usage. ``` -The JAX Device Memory Profiler allows us to explore how and why JAX programs are +The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to: * Figure out which arrays and executables are in GPU memory at a given time, or diff --git a/docs/glossary.rst b/docs/glossary.rst index 179c3c75d..a7668e9a0 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -28,7 +28,7 @@ JAX glossary of terms able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`). jaxpr - Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that + Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. See :ref:`understanding-jaxprs` for more discussion and examples. diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 4affae3a6..9b712056e 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -23,7 +23,7 @@ Here is a suggested investigation strategy: 2. Hourly recompilation while keeping XLA and JAX in sync. 3. Final verification: maybe a manual check of a few commits (or a git bisect). -## Nightly investigation. +## Nightly investigation This can be done by using [JAX-Toolbox nightly containers](https://github.com/NVIDIA/JAX-Toolbox). @@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown earlier, lets ignore it for this example. It would be only another hourly investigation between those dates. -## Hourly investigation. +## Hourly investigation This does a checkout of JAX and XLA at each hour between the 2 dates, rebuilds everything and runs the test. The scripts are structured diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index 56be62162..d7b50dcb3 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -164,7 +164,7 @@ before (with two input vars, one for each element of the input tuple) -Constant Vars +Constant vars ------------- Some values in jaxprs are constants, in that their value does not depend on the diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index f0c157655..16e623d0f 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -6,7 +6,7 @@ "id": "18AF5Ab4p6VL" }, "source": [ - "# Training a Simple Neural Network, with PyTorch Data Loading\n", + "# Training a simple neural network, with PyTorch data loading\n", "\n", "\n", "\n", @@ -261,7 +261,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with PyTorch\n", + "## Data loading with PyTorch\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays." ] @@ -494,7 +494,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 2c53bb1e4..87533117e 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -14,7 +14,7 @@ kernelspec: +++ {"id": "18AF5Ab4p6VL"} -# Training a Simple Neural Network, with PyTorch Data Loading +# Training a simple neural network, with PyTorch data loading @@ -175,7 +175,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with PyTorch +## Data loading with PyTorch JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays. @@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 0a8233530..5246e810d 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -6,7 +6,7 @@ "id": "TVT_MVvc02AA" }, "source": [ - "# Generalized Convolutions in JAX\n", + "# Generalized convolutions in JAX\n", "\n", "\n", "\n", @@ -28,7 +28,7 @@ "id": "ewZEn2X12-Ng" }, "source": [ - "## Basic One-dimensional Convolution\n", + "## Basic one-dimensional convolution\n", "\n", "Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:" ] @@ -91,7 +91,7 @@ "id": "5ndvLDIH4rv6" }, "source": [ - "## Basic N-dimensional Convolution\n", + "## Basic N-dimensional convolution\n", "\n", "For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n", "\n", @@ -160,7 +160,7 @@ "id": "bxuUjFVG-v1h" }, "source": [ - "## General Convolutions" + "## General convolutions" ] }, { diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 3de8f261a..2dec35847 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -14,7 +14,7 @@ kernelspec: +++ {"id": "TVT_MVvc02AA"} -# Generalized Convolutions in JAX +# Generalized convolutions in JAX @@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are +++ {"id": "ewZEn2X12-Ng"} -## Basic One-dimensional Convolution +## Basic one-dimensional convolution Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution: @@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d +++ {"id": "5ndvLDIH4rv6"} -## Basic N-dimensional Convolution +## Basic N-dimensional convolution For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions. @@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l +++ {"id": "bxuUjFVG-v1h"} -## General Convolutions +## General convolutions +++ {"id": "0pcn2LeS-03b"} diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index bdf71004c..050a641a7 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,7 @@ "id": "7XNMxdTwURqI" }, "source": [ - "# External Callbacks in JAX\n", + "# External callbacks in JAX\n", "\n", "" ] diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 857eef42e..ab0a2fcd3 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "7XNMxdTwURqI"} -# External Callbacks in JAX +# External callbacks in JAX diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 95c00bf1e..7d353c924 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -36,7 +36,7 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", + "# Training a simple neural network, with tensorflow/datasets data loading\n", "\n", "\n", "\n", @@ -274,7 +274,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with `tensorflow/datasets`\n", + "## Data loading with `tensorflow/datasets`\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader." ] @@ -344,7 +344,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 8f795484d..2f7ba3271 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -34,7 +34,7 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -# Training a Simple Neural Network, with tensorflow/datasets Data Loading +# Training a simple neural network, with tensorflow/datasets data loading @@ -183,7 +183,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with `tensorflow/datasets` +## Data loading with `tensorflow/datasets` JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader. @@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 1c1c9729b..e4f9d888e 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -6,7 +6,7 @@ "id": "LQHmwePqryRU" }, "source": [ - "# How to Think in JAX\n", + "# How to think in JAX\n", "\n", "\n", "\n", @@ -23,7 +23,7 @@ "source": [ "## JAX vs. NumPy\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JAX provides a NumPy-inspired interface for convenience.\n", "- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n", @@ -282,7 +282,7 @@ "source": [ "## NumPy, lax & XLA: JAX API layering\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n", "- `jax.lax` is a lower-level API that is stricter and often more powerful.\n", @@ -475,7 +475,7 @@ "source": [ "## To JIT or not to JIT\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- By default JAX executes operations one at a time, in sequence.\n", "- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n", @@ -675,7 +675,7 @@ "source": [ "## JIT mechanics: tracing and static variables\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n", "\n", @@ -932,9 +932,9 @@ "id": "r-RCl_wD5lI7" }, "source": [ - "## Static vs Traced Operations\n", + "## Static vs traced operations\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- Just as values can be either static or traced, operations can be static or traced.\n", "\n", diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 14089fa36..16be7b9e9 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "LQHmwePqryRU"} -# How to Think in JAX +# How to think in JAX @@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b ## JAX vs. NumPy -**Key Concepts:** +**Key concepts:** - JAX provides a NumPy-inspired interface for convenience. - Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays. @@ -132,7 +132,7 @@ print(y) ## NumPy, lax & XLA: JAX API layering -**Key Concepts:** +**Key concepts:** - `jax.numpy` is a high-level wrapper that provides a familiar interface. - `jax.lax` is a lower-level API that is stricter and often more powerful. @@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op ## To JIT or not to JIT -**Key Concepts:** +**Key concepts:** - By default JAX executes operations one at a time, in sequence. - Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once. @@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp ## JIT mechanics: tracing and static variables -**Key Concepts:** +**Key concepts:** - JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. @@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace +++ {"id": "r-RCl_wD5lI7"} -## Static vs Traced Operations +## Static vs traced operations -**Key Concepts:** +**Key concepts:** - Just as values can be either static or traced, operations can be static or traced. diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index af20aa7bb..1d6a5d9b7 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,4 +1,4 @@ -# Persistent Compilation Cache +# Persistent compilation cache diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 5a8af2b74..4e4063a68 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# Stateful Computations +# Stateful computations