rocm_jax/docs/persistent_compilation_cache.md
Jaroslav Sevcik cb2ab409f6 Add docs
2024-12-13 15:29:56 +00:00


# Persistent compilation cache
<!--* freshness: { reviewed: '2024-11-07' } *-->
JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.
Note: if the compilation cache is not on a local filesystem,
[etils](https://pypi.org/project/etils/) needs to be installed.
```sh
pip install etils
```
## Usage
### Quick start
```python
import jax
import jax.numpy as jnp
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
@jax.jit
def f(x):
    return x + 1

x = jnp.zeros((2, 2))
f(x)
```
### Setting cache directory
The compilation cache is enabled when the
[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
(1) Using an environment variable
In the shell, before running the script:
```sh
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
```
Or at the top of the Python script:
```python
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
```
(2) Using `jax.config.update()`
```python
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
```
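One quick way to confirm that a local cache location is actually being used is to look at what ends up in the directory after a run. The helper below is a hypothetical convenience function, not part of JAX. It assumes the cache is a directory on a local file system and makes no assumption about file names or the on-disk format.

```python
import os
import tempfile


def summarize_cache_dir(cache_dir):
    """Return (number_of_files, total_size_in_bytes) for a local directory."""
    count, total = 0, 0
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            count += 1
            total += os.path.getsize(os.path.join(root, name))
    return count, total


# Demonstration on a throwaway directory standing in for a cache location.
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "entry"), "wb") as fh:
        fh.write(b"\0" * 128)
    demo_summary = summarize_cache_dir(demo_dir)

print(demo_summary)  # (1, 128)
```

Calling `summarize_cache_dir("/tmp/jax_cache")` before and after running a script shows whether new entries were written.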
### Caching thresholds
* `jax_persistent_cache_min_compile_time_secs`: A computation is written to
  the persistent cache only if its compilation time is longer than the
  specified value. The default is 1.0 second.
* `jax_persistent_cache_min_entry_size_bytes`: The minimum size (in bytes)
  of an entry that will be cached in the persistent compilation cache:
  * `-1`: disable the size restriction and prevent overrides.
  * Leave at default (`0`) to allow for overrides. The override will
    typically ensure that the minimum size is optimal for the file system
    being used for the cache.
  * `> 0`: the actual minimum size desired; no overrides.

Note that both criteria need to be satisfied for a function to be cached.
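The way the two thresholds combine can be sketched as a small predicate. This is a toy model, not JAX's implementation: the function name `should_write_to_cache` is hypothetical, file-system overrides of the default `0` are ignored, and the behavior exactly at the boundary values is an assumption.

```python
def should_write_to_cache(compile_time_secs, entry_size_bytes,
                          min_compile_time_secs=1.0,
                          min_entry_size_bytes=0):
    """Toy model of the two caching thresholds (hypothetical helper)."""
    # Criterion 1: compilation must take longer than the time threshold.
    slow_enough = compile_time_secs > min_compile_time_secs
    # Criterion 2: -1 and 0 impose no minimum size in this model
    # (any file-system override of the default 0 is ignored here).
    if min_entry_size_bytes <= 0:
        large_enough = True
    else:
        large_enough = entry_size_bytes >= min_entry_size_bytes
    # Both criteria must hold for the entry to be written.
    return slow_enough and large_enough


print(should_write_to_cache(0.2, 4096))   # False: compiled too quickly
print(should_write_to_cache(2.5, 4096))   # True
print(should_write_to_cache(2.5, 100, min_entry_size_bytes=1024))  # False
print(should_write_to_cache(0.2, 100, min_compile_time_secs=0,
                            min_entry_size_bytes=-1))  # True
```

The last call mirrors the quick-start settings, which remove both restrictions so that even small, fast compilations are cached.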
### Additional caching
XLA supports additional caching mechanisms, which can be enabled alongside JAX's
persistent compilation cache to further reduce recompilation time.
* `jax_persistent_cache_enable_xla_caches`: Possible values:
  * `all`: enable all XLA caching features
  * `none`: don't enable any extra XLA caching features
  * `xla_gpu_kernel_cache_file`: only enable the kernel cache
  * `xla_gpu_per_fusion_autotune_cache_dir`: (default value) only enable the
    autotuning cache
### Google Cloud
When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:
* Create the bucket in the same region where the workload will run.
* Create the bucket in the same project as the workload's VM(s). Ensure that
permissions are set so that the VM(s) can write to the bucket.
* There is no need for replication for smaller workloads. Larger workloads
could benefit from replication.
* Use “Standard” for the default storage class for the bucket.
* Set the soft delete policy to its shortest: 7 days.
* Set the object lifecycle to the expected duration of the workload run.
For example, if the workload is expected to run for 10 days, set the object
lifecycle to 10 days. That should cover restarts that occur during the entire
run. Use `age` for the lifecycle condition and `Delete` for the action. See
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
for details. If the object lifecycle is not set, the cache will continue to
grow since there is no eviction mechanism implemented.
* All encryption policies are supported.
Assuming that `gs://jax-cache` is the GCS bucket, set the cache location as
follows:
```python
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```
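The object lifecycle recommendation above (an `age` condition with a `Delete` action) can be written down as a lifecycle configuration file. The sketch below only builds the JSON; the helper name `lifecycle_config` and the output file name are hypothetical, and the layout follows the format described in the Object Lifecycle Management documentation linked above, which should be consulted for how to apply the file to a bucket.

```python
import json


def lifecycle_config(max_age_days):
    """Build a lifecycle rule that deletes objects older than max_age_days."""
    return {
        "rule": [
            {
                "action": {"type": "Delete"},
                "condition": {"age": max_age_days},
            }
        ]
    }


# Example: a workload expected to run for 10 days.
config = lifecycle_config(10)
with open("jax_cache_lifecycle.json", "w") as fh:
    json.dump(config, fh, indent=2)
print(json.dumps(config, indent=2))
```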
## How it works
The cache key is the signature for a compiled function containing the
following parameters:
* The computation performed by the function, captured by the non-optimized HLO of the JAX function being hashed
* The jaxlib version
* Relevant XLA compilation flags
* Device configuration, captured in general by the number of devices and the topology of the devices.
Currently, for GPUs, the topology only contains a string representation of the GPU name
* Compression algorithm used to compress the compiled executable
* A string produced by `jax._src.cache_key.custom_hook()`. This function can
be reassigned to a user-defined function, so that the resulting string can
be altered. By default, this function always returns an empty string.
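To make the idea of a composite key concrete, here is a minimal sketch that hashes stand-ins for the components listed above. It is an illustration only: `toy_cache_key`, its parameters, and the use of SHA-256 are assumptions of this example and do not reflect how JAX actually serializes or hashes these inputs.

```python
import hashlib


def toy_cache_key(hlo_text, jaxlib_version, xla_flags, num_devices,
                  device_kind, compression, custom_hook=lambda: ""):
    """Hash illustrative stand-ins for the cache key components."""
    components = [
        hlo_text,                      # computation (non-optimized HLO)
        jaxlib_version,                # jaxlib version
        ",".join(sorted(xla_flags)),   # relevant XLA flags
        str(num_devices),              # device configuration
        device_kind,                   # e.g. the GPU name
        compression,                   # compression algorithm
        custom_hook(),                 # user-adjustable string, "" by default
    ]
    digest = hashlib.sha256()
    for component in components:
        data = component.encode("utf-8")
        # A length prefix keeps adjacent components from running together.
        digest.update(len(data).to_bytes(8, "little"))
        digest.update(data)
    return digest.hexdigest()


base = toy_cache_key("add(x, 1)", "0.4.26", ["--flag_a"], 8, "GPU-X", "zstd")
same = toy_cache_key("add(x, 1)", "0.4.26", ["--flag_a"], 8, "GPU-X", "zstd")
upgraded = toy_cache_key("add(x, 1)", "0.4.27", ["--flag_a"], 8, "GPU-X", "zstd")
print(base == same)      # True: identical inputs give the same key
print(base == upgraded)  # False: a different jaxlib version changes the key
```

In this model, changing any single component, including the string returned by the custom hook, produces a different key and therefore a cache miss.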
## Caching on multiple nodes
The first time a program is run (the persistent cache is cold / empty) all processes will compile,
but only the process with rank 0 in the global communication group will write to the persistent cache.
In subsequent runs, all processes will attempt to read from the persistent cache,
so it is important for the persistent cache to be in a shared file system (e.g., NFS) or remote storage (e.g., GFS).
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
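The difference between a shared cache and one that is local to rank 0 can be seen in a small simulation. This is a toy model of the behavior described above, not JAX code: `simulate_run` is a hypothetical function, and each cache is represented by a plain Python set.

```python
def simulate_run(num_processes, caches):
    """Return the ranks that had to compile during one run.

    caches[rank] is the cache visible to that rank (a set of program names).
    """
    # Every process first looks up the program in the cache it can see.
    compiled = [rank for rank in range(num_processes)
                if "program" not in caches[rank]]
    # Only rank 0 writes its compiled program to the persistent cache.
    if 0 in compiled:
        caches[0].add("program")
    return compiled


# Shared storage: every rank sees the same cache object.
shared = set()
shared_view = [shared] * 4
print(simulate_run(4, shared_view))  # [0, 1, 2, 3] on the cold run
print(simulate_run(4, shared_view))  # [] on the next run

# Cache local to each rank: only rank 0 ever gets an entry.
local_view = [set() for _ in range(4)]
print(simulate_run(4, local_view))   # [0, 1, 2, 3] on the cold run
print(simulate_run(4, local_view))   # [1, 2, 3] on the next run
```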
### Pre-compiling multi-node programs on single node
JAX can populate the compilation cache with compiled programs for multiple nodes
on a single node. Preparing the cache on a single node helps to decrease the costly
compilation time on a cluster. To compile and run multi-node programs on a single
node, users can create fake remote devices using
the `jax_mock_gpu_topology` configuration option.
For instance, the snippet below instructs JAX to mock a cluster with four
nodes, each node running eight processes with each process attached to one GPU.
```python
jax.config.update("jax_mock_gpu_topology", "4x8x1")
```
After populating the cache with this config, users can run the program
without recompilation on four nodes, eight processes per node,
one GPU per process.
Important notes:
* The process running the mocked program must have the same number of GPUs
and the same GPU model as the nodes that would use the cache. For instance,
a mocked topology `8x4x2` must run in a process with two GPUs.
* When running programs with mocked topology, the results of communications
with other nodes are undefined, so the outputs of JAX programs running
in mocked environments will likely be incorrect.
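Reading the topology string as nodes × processes per node × GPUs per process, as in the examples above, a small parser makes the requirement on the local machine explicit. The names `MockTopology` and `parse_mock_topology` are hypothetical helpers for this illustration and are not part of JAX.

```python
from typing import NamedTuple


class MockTopology(NamedTuple):
    nodes: int
    processes_per_node: int
    gpus_per_process: int

    @property
    def total_gpus(self):
        return self.nodes * self.processes_per_node * self.gpus_per_process


def parse_mock_topology(spec):
    """Parse a string such as "4x8x1" into its three components."""
    parts = spec.split("x")
    if len(parts) != 3:
        raise ValueError(f"expected three 'x'-separated integers, got {spec!r}")
    nodes, processes, gpus = (int(part) for part in parts)
    return MockTopology(nodes, processes, gpus)


topology = parse_mock_topology("4x8x1")
print(topology.total_gpus)  # 32 mocked GPUs in total
# The process that populates the cache needs this many real GPUs:
print(parse_mock_topology("8x4x2").gpus_per_process)  # 2
```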
## Logging cache activity
It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
Here are a few suggestions on how to begin.
Users can enable the logging of related source files by placing
```python
import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
```
at the top of the script. Alternatively, you can change the global JAX logging level with
```python
import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
import jax
jax.config.update("jax_logging_level", "DEBUG")
```
### Examining cache misses
To examine and understand why there are cache misses, JAX includes a configuration flag that
enables the logging of all cache misses (including persistent compilation cache misses) with their explanations.
Currently this is only implemented for tracing cache misses, but the eventual goal is to
explain all cache misses. The flag can be enabled with the following configuration.
```python
jax.config.update("jax_explain_cache_misses", True)
```
## Pitfalls
Two pitfalls are currently known:
* The persistent cache currently doesn't work with functions that have host callbacks. In this situation, caching is avoided completely.
  - This is because the HLO contains a pointer to the callback, which changes from run to run even if the computation and compute infrastructure are exactly the same.
* The persistent cache currently doesn't work with functions that use primitives that implement their own `custom_partitioning`.
  - The HLO of the function contains a pointer to the `custom_partitioning` callback, which leads to different cache keys for the same computation across runs.
  - In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
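The effect of a pointer embedded in the hashed program can be imitated with a toy example. Everything here is an analogy rather than JAX behavior: `render_program` is a hypothetical stand-in for HLO text, and two distinct function objects in one process stand in for the same callback living at different addresses in two separate runs.

```python
import hashlib


def toy_key(program_text):
    """Hash a textual program representation (stand-in for a cache key)."""
    return hashlib.sha256(program_text.encode("utf-8")).hexdigest()


def render_program(callback):
    """Stand-in for HLO text that embeds the address of a callback."""
    return f"add(x, 1); call_host(ptr={id(callback):#x})"


def callback_run_1(x):
    return x


def callback_run_2(x):
    return x


# Same computation, but the embedded address differs, so the keys differ.
key_1 = toy_key(render_program(callback_run_1))
key_2 = toy_key(render_program(callback_run_2))
print(key_1 == key_2)  # False

# Without the embedded address the key is stable.
print(toy_key("add(x, 1)") == toy_key("add(x, 1)"))  # True
```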
### Working around `custom_partitioning`
As mentioned, the compilation cache doesn't work with a function that is composed of primitives that implement `custom_partitioning`. However, it is possible to use `shard_map` to circumvent `custom_partitioning` for those primitives that do implement it and make the compilation cache work as expected.
Let's pretend we have a function `F` that implements a layernorm followed by a matrix multiplication using a primitive `LayerNorm` that implements `custom_partitioning`:
```python
import jax
def F(x1, x2, gamma, beta):
    ln_out = LayerNorm(x1, gamma, beta)
    return ln_out @ x2
```
If we were to merely compile this function without `shard_map`, the cache key for `layernorm_matmul_without_shard_map` would be different every time we ran the same code:
```python
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_shardings=(...))(x1, x2, gamma, beta)
```
However, if we were to wrap the layernorm primitive in `shard_map` and define a function `G` that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` would be the same every time, despite `LayerNorm` implementing `custom_partitioning`:
```python
import jax
from jax.experimental.shard_map import shard_map
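def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
    # Wrap only the primitive that implements custom_partitioning; it takes
    # the same arguments as in F, i.e. (x1, gamma, beta).
    ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, gamma, beta)
    return ln_out @ x2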
ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
```
Note that the primitive that implements `custom_partitioning` must be wrapped in `shard_map` for this workaround. It is insufficient to wrap the outer function `F` in `shard_map`.