Merge pull request #22271 from ayaka14732:lru-cache-6

PiperOrigin-RevId: 650203793
jax authors 2024-07-08 04:39:58 -07:00
commit 0d4e0ecf65


JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly.
## Usage
### Quick start
```python
import jax
import jax.numpy as jnp

# Set the cache directory and persist every compiled computation,
# regardless of entry size or compile time.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return x + 1

x = jnp.zeros((2, 2))
f(x)  # the compiled executable for f is written to the cache
```
### Setting cache directory
The compilation cache is enabled when the
[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location in
one of the following ways:

(1) Using an environment variable

In the shell, before running the script:

```sh
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
```

Or at the top of the Python script:

```python
import os

os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
```
(2) Using `jax.config.update()`
```python
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
```
### Caching thresholds
* `jax_persistent_cache_min_compile_time_secs`: A computation will only be
  written to the persistent cache if its compilation time is longer than
  the specified value. It defaults to 1.0 second.
* `jax_persistent_cache_min_entry_size_bytes`: The minimum size (in bytes)
  of an entry that will be cached in the persistent compilation cache:
  * `-1`: disable the size restriction and prevent overrides.
  * Leave at default (`0`) to allow for overrides. The override will
    typically ensure that the minimum size is optimal for the file system
    being used for the cache.
  * `> 0`: the actual minimum size desired; no overrides.
Note that both criteria need to be satisfied for a function to be cached.
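For example, here is a sketch tuning both thresholds together (the values below are illustrative, not recommendations):

```python
import jax

# Persist only executables that are at least 1 MiB and took at least
# 2 seconds to compile; anything smaller or faster is recompiled each run.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 1024 * 1024)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 2.0)
```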
### Google Cloud
When running on Google Cloud, the compilation cache can be placed on a Google Cloud Storage (GCS) bucket. We recommend the following configuration:
* All encryption policies are supported.
Assuming that `gs://jax-cache` is the GCS bucket, set the cache location as
follows:
```python
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```
## How it works
The cache key is the signature for a compiled function. It is built from the
following parameters:
* The computation performed by the function, captured by the non-optimized HLO of the JAX function being hashed
* The jaxlib version
* Relevant XLA compilation flags
* Device configuration, captured in general by the number of devices and the topology of the devices.
  Currently for GPUs, the topology only contains a string representation of the GPU name
* Compression algorithm used to compress the compiled executable
* A string produced by `jax._src.cache_key.custom_hook()`. This function can
  be reassigned to a user-defined function, so that the resulting string can
  be altered. By default, this function always returns an empty string.
When the signature for a function, created using the parameters above, matches
that of a compiled function in the persistent cache, the function will not be
compiled; it will simply be read and deserialized from the persistent cache.
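As an illustration, here is a minimal sketch of reassigning that hook, assuming (per the description above) it is a zero-argument function returning a string; the environment variable name is hypothetical:

```python
import os

import jax._src.cache_key as cache_key

def my_hook() -> str:
    # Fold a user-chosen tag into every cache key, so runs launched with
    # different tags never share cache entries. MY_EXPERIMENT_TAG is a
    # made-up name for this sketch.
    return os.environ.get("MY_EXPERIMENT_TAG", "")

cache_key.custom_hook = my_hook
```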
## Different Runtimes
Below we outline some observed behavior of the persistent compilation cache
in a variety of runtimes, as it relates to which processes write to the cache.
### Single node with single process and single device
This is the simplest runtime: the only process that compiles and writes to the compilation cache is the single process.
The number of devices does not matter to the cache key in this setup, only the type of device does.
### Single node with single process and multiple devices
Once again, the only process that compiles and writes to the compilation cache is the single process.
The difference from the previous setup is that the number of devices now matters in addition to the type of device.
### Multiple processes and multiple devices (on either single or multiple nodes)
In this runtime, the first time a program is run (the persistent cache is cold / empty), all processes will compile,
but only the process with rank 0 in the global communication group will write to the persistent cache.
In subsequent runs, all processes will attempt to read from the persistent cache,
so it is important for the persistent cache to be in a shared file system (e.g. NFS) or remote storage (e.g. GFS).
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
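As a sketch of such a setup (assuming an environment where `jax.distributed.initialize()` can auto-detect the cluster configuration; the mount path is illustrative):

```python
import jax

# Join the global communication group; rank 0 will be the only writer
# to the persistent cache.
jax.distributed.initialize()

# Point every process at the same directory on a shared file system
# (e.g. NFS) so that non-zero ranks can read what rank 0 wrote.
jax.config.update("jax_compilation_cache_dir", "/mnt/shared/jax_cache")
```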
## Logging cache activity
It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
Here are a few suggestions on how to begin.
Users can enable the logging of related source files by placing

```python
import os

# Enable debug logging for the modules involved in compilation caching.
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
```

at the top of the script.
### Examining cache misses
To examine and understand why there are cache misses, JAX includes a configuration flag that
enables the logging of all cache misses (including persistent compilation cache misses) along with their explanations.
Currently this is only implemented for tracing cache misses, but the eventual goal is to
explain all cache misses. This can be enabled by setting the following configuration:
```python
import jax
jax.config.update("jax_explain_cache_misses", True)
```
## Pitfalls
There are a couple of pitfalls that have currently been discovered:

* Currently the persistent cache doesn't work with functions that have host callbacks. In this situation, caching is completely avoided (see the sketch after this list).
  - This is because the HLO contains a pointer to the callback and changes from run to run even if the computation and compute infrastructure is exactly the same.
* Currently the persistent cache doesn't work with a function that uses primitives that implement their own `custom_partitioning`.
  - The HLO of the function contains a pointer to the `custom_partitioning` callback, which leads to different cache keys for the same computation across runs.
  - In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
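For instance, here is a minimal sketch of the first pitfall; `jax.debug.print` is implemented via a host callback, so a jitted function that uses it will miss the persistent cache on every run:

```python
import jax

@jax.jit
def g(x):
    # The host callback behind jax.debug.print embeds a runtime pointer
    # in the HLO, which differs from run to run, so this function is
    # never served from the persistent cache.
    jax.debug.print("x = {}", x)
    return x + 1
```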
### Working around `custom_partitioning`

As mentioned, the compilation cache doesn't work with a function that is composed of primitives that implement `custom_partitioning`. However, it is possible to use `shard_map` to circumvent `custom_partitioning` for those primitives that do implement it and make the compilation cache work as expected:

Let's pretend we have a function `F` that implements a layernorm followed by a matrix multiplication, using a primitive `LayerNorm` that implements `custom_partitioning`:
```python
import jax

# LayerNorm stands in for a primitive that implements custom_partitioning.
def F(x1, x2, gamma, beta):
    ln_out = LayerNorm(x1, gamma, beta)
    return ln_out @ x2
```
If we were to merely compile this function without `shard_map`, the cache key for `layernorm_matmul_without_shard_map` would be different every time we ran the same code:
```python
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_shardings=(...))(x1, x2, gamma, beta)
```
However, if we were to wrap the layernorm primitive in `shard_map` and define a function `G` that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time, despite `LayerNorm` implementing `custom_partitioning`:
```python
import jax

# ... definition of G, which wraps the LayerNorm primitive in shard_map,
# is elided here ...

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
```
Note that the primitive that implements `custom_partitioning` must be wrapped in `shard_map` for this workaround. It is insufficient to wrap the outer function `F` in `shard_map`.