mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix doc typos
This commit is contained in:
parent
187b7aa8e6
commit
75da54e9f4
@ -554,7 +554,7 @@ ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,51
|
||||
True
|
||||
```
|
||||
|
||||
The values have been computed correctly for forward operation, however, the generated HLO modules shows an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead.
|
||||
The values have been computed correctly for forward operation, however, the generated HLO modules show an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead.
|
||||
|
||||
As XLA does not have enough knowledge about the custom functions to shard input tensors, it decides to replicate them to produce correct values before making the custom call.
|
||||
|
||||
|
@ -150,7 +150,7 @@ def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> No
|
||||
return None
|
||||
```
|
||||
|
||||
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might reordered and asynchronous (see below).
|
||||
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might be reordered and asynchronous (see below).
|
||||
|
||||
### Sharp bits
|
||||
Like most JAX APIs, `jax.debug.print` can cut you if you're not careful.
|
||||
|
@ -6,7 +6,7 @@ JAX Frequently Asked Questions (FAQ)
|
||||
|
||||
.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
|
||||
|
||||
We are collecting here answers to frequently asked questions.
|
||||
We are collecting answers to frequently asked questions here.
|
||||
Contributions welcome!
|
||||
|
||||
``jit`` changes the behavior of my function
|
||||
|
@ -1,4 +1,4 @@
|
||||
# GPU peformance tips
|
||||
# GPU performance tips
|
||||
|
||||
This document focuses on performance tips for neural network workloads
|
||||
|
||||
|
@ -306,7 +306,7 @@ and :py:func:`jax.lax.fori_loop`
|
||||
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
|
||||
|
||||
|
||||
In the above signature, “C” stands for the type of a the loop “carry” value.
|
||||
In the above signature, “C” stands for the type of the loop “carry” value.
|
||||
For example, here is an example fori loop
|
||||
|
||||
>>> import numpy as np
|
||||
|
@ -50,7 +50,7 @@ active for a portion of your script, you can shut it down by calling
|
||||
`jax.profiler.stop_server()`.
|
||||
|
||||
Once the script is running and after the profiler server has started, we can
|
||||
manually capture an trace by running:
|
||||
manually capture and trace by running:
|
||||
```bash
|
||||
$ python -m jax.collect_profile <port> <duration_in_ms>
|
||||
```
|
||||
@ -218,7 +218,7 @@ from a running program.
|
||||
|
||||
### Concurrent kernel tracing on GPU
|
||||
|
||||
By default, traces captured on GPU in a mode that prevents CUDA kernels from
|
||||
By default, traces are captured on GPU in a mode that prevents CUDA kernels from
|
||||
running concurrently. This allows for more accurate kernel timings, but removes
|
||||
any concurrency between streams (for example, between compute and
|
||||
communication). To enable concurrent kernel tracing, set the environment
|
||||
|
@ -280,7 +280,7 @@ class RegisteredSpecial2(Special):
|
||||
show_example(RegisteredSpecial2(1., 2.))
|
||||
```
|
||||
|
||||
When defining an unflattening functions, in general `children` should contain all the
|
||||
When defining unflattening functions, in general `children` should contain all the
|
||||
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while
|
||||
`aux_data` should contain all the static elements that will be rolled into the `treedef`
|
||||
structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash
|
||||
|
@ -581,7 +581,7 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
code by including information about each stage. Respectively:
|
||||
|
||||
1. The name of the transformed function (``side_effecting``) and which
|
||||
transform kicked of the trace :func:`~jax.jit`).
|
||||
transform kicked off the trace :func:`~jax.jit`).
|
||||
2. A reconstructed stack trace of where the leaked Tracer was created,
|
||||
which includes where the transformed function was called.
|
||||
(``When the Tracer was created, the final 5 stack frames were...``).
|
||||
@ -589,7 +589,7 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
the leaked Tracer.
|
||||
4. The leak location is not included in the error message because it is
|
||||
difficult to pin down! JAX can only tell you what the leaked value
|
||||
looks like (what shape is has and where it was created) and what
|
||||
looks like (what shape it has and where it was created) and what
|
||||
boundary it was leaked over (the name of the transformation and the
|
||||
name of the transformed function).
|
||||
5. The current error's stack trace points to where the value is used.
|
||||
|
Loading…
x
Reference in New Issue
Block a user