Fix doc typos

This commit is contained in:
rajasekharporeddy 2024-03-13 14:48:31 +05:30
parent 187b7aa8e6
commit 75da54e9f4
8 changed files with 10 additions and 10 deletions

View File

@ -554,7 +554,7 @@ ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,51
True
```
The values have been computed correctly for forward operation, however, the generated HLO modules shows an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead.
The values have been computed correctly for forward operation, however, the generated HLO modules show an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead.
As XLA does not have enough knowledge about the custom functions to shard input tensors, it decides to replicate them to produce correct values before making the custom call.

View File

@ -150,7 +150,7 @@ def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> No
return None
```
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might reordered and asynchronous (see below).
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might be reordered and asynchronous (see below).
### Sharp bits
Like most JAX APIs, `jax.debug.print` can cut you if you're not careful.

View File

@ -6,7 +6,7 @@ JAX Frequently Asked Questions (FAQ)
.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
We are collecting here answers to frequently asked questions.
We are collecting answers to frequently asked questions here.
Contributions welcome!
``jit`` changes the behavior of my function

View File

@ -1,4 +1,4 @@
# GPU peformance tips
# GPU performance tips
This document focuses on performance tips for neural network workloads

View File

@ -306,7 +306,7 @@ and :py:func:`jax.lax.fori_loop`
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
In the above signature, “C” stands for the type of a the loop “carry” value.
In the above signature, “C” stands for the type of the loop “carry” value.
For example, here is an example fori loop
>>> import numpy as np

View File

@ -50,7 +50,7 @@ active for a portion of your script, you can shut it down by calling
`jax.profiler.stop_server()`.
Once the script is running and after the profiler server has started, we can
manually capture an trace by running:
manually capture and trace by running:
```bash
$ python -m jax.collect_profile <port> <duration_in_ms>
```
@ -218,7 +218,7 @@ from a running program.
### Concurrent kernel tracing on GPU
By default, traces captured on GPU in a mode that prevents CUDA kernels from
By default, traces are captured on GPU in a mode that prevents CUDA kernels from
running concurrently. This allows for more accurate kernel timings, but removes
any concurrency between streams (for example, between compute and
communication). To enable concurrent kernel tracing, set the environment

View File

@ -280,7 +280,7 @@ class RegisteredSpecial2(Special):
show_example(RegisteredSpecial2(1., 2.))
```
When defining an unflattening functions, in general `children` should contain all the
When defining unflattening functions, in general `children` should contain all the
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while
`aux_data` should contain all the static elements that will be rolled into the `treedef`
structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash

View File

@ -581,7 +581,7 @@ class UnexpectedTracerError(JAXTypeError):
code by including information about each stage. Respectively:
1. The name of the transformed function (``side_effecting``) and which
transform kicked of the trace :func:`~jax.jit`).
transform kicked off the trace :func:`~jax.jit`).
2. A reconstructed stack trace of where the leaked Tracer was created,
which includes where the transformed function was called.
(``When the Tracer was created, the final 5 stack frames were...``).
@ -589,7 +589,7 @@ class UnexpectedTracerError(JAXTypeError):
the leaked Tracer.
4. The leak location is not included in the error message because it is
difficult to pin down! JAX can only tell you what the leaked value
looks like (what shape is has and where it was created) and what
looks like (what shape it has and where it was created) and what
boundary it was leaked over (the name of the transformation and the
name of the transformed function).
5. The current error's stack trace points to where the value is used.