## Usage

### Quick start

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return x + 1

x = jnp.zeros((2, 2))
f(x)
```
### Setting cache directory

The compilation cache is enabled when the
[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:

(1) Using environment variable

In shell, before running the script:

```sh
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
```

Or at the top of the Python script:

```python
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
```

(2) Using `jax.config.update()`

```python
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```

(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)

```python
from jax.experimental.compilation_cache import compilation_cache as cc

cc.set_cache_dir("/tmp/jax_cache")
```

Note: the cache does not have an eviction mechanism implemented. If the
cache location is a directory on the local filesystem, its size will continue
to grow unless files are manually deleted.
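Since nothing is evicted automatically, a periodic manual sweep can keep a local cache bounded. Below is a minimal sketch, assuming the `/tmp/jax_cache` location from the quick start and an arbitrary 30-day age limit; it is not part of JAX itself:

```python
import os
import time

def sweep_cache(cache_dir, max_age_days=30):
    """Delete cache files not modified within max_age_days."""
    cutoff = time.time() - max_age_days * 24 * 3600
    removed = 0
    for name in os.listdir(cache_dir):
        path = os.path.join(cache_dir, name)
        # Only plain files are removed; the cutoff is based on mtime.
        if os.path.isfile(path) and os.path.getmtime(path) < cutoff:
            os.remove(path)
            removed += 1
    return removed

# Example: sweep the cache directory used in the quick start, if it exists.
if os.path.isdir("/tmp/jax_cache"):
    sweep_cache("/tmp/jax_cache")
```

Running such a sweep from cron or a job prologue is one simple way to bound cache growth.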
### Caching thresholds

* `jax_persistent_cache_min_compile_time_secs`: a computation will only be
  written to the persistent cache if its compilation time is longer than
  the specified value. The default is 1.0 second.

* `jax_persistent_cache_min_entry_size_bytes`: the minimum size (in bytes)
  of an entry that will be cached in the persistent compilation cache:

  * `-1`: disable the size restriction and prevent overrides.

  * Leave at default (`0`) to allow for overrides. The override will
    typically ensure that the minimum size is optimal for the file system
    being used for the cache.

  * `> 0`: the actual minimum size desired; no overrides.

Note that both criteria need to be satisfied for a function to be cached.
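For example, to cache every executable that is compiled, disable the size restriction and zero out the compile-time threshold, as in the quick start:

```python
import jax

# -1 disables the entry-size restriction; 0 caches regardless of compile time.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```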
### Google Cloud

When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:

* All encryption policies are supported.

Assuming that `gs://jax-cache` is the GCS bucket, set the cache location as
follows:

```python
import jax

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```
## How it works

The cache key is the signature for a compiled function, containing the
following parameters:

* The computation performed by the function, captured by the non-optimized HLO of the JAX function being hashed

* The jaxlib version

* Relevant XLA compilation flags

* Device configuration, captured in general by the number of devices and the topology of the devices.
  Currently for GPUs, the topology only contains a string representation of the GPU name

* Compression algorithm used to compress the compiled executable

* A string produced by `jax._src.cache_key.custom_hook()`. This function can
  be reassigned to a user-defined function, so that the resulting string can
  be altered. By default, this function always returns an empty string.
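As a sketch of the last point, the hook can be reassigned to stamp every cache key with a user-chosen string; the tag below is a made-up example:

```python
import jax._src.cache_key as cache_key

# Every cache key now incorporates this string; changing the tag
# invalidates previously cached executables.
cache_key.custom_hook = lambda: "experiment-v2"
```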
When the signature for a function created using the parameters above matches
that of a compiled function in the persistent cache, the function will not be compiled
but will just be read and deserialized from the persistent cache.
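The key derivation can be pictured as hashing all of those components together. Here is a simplified, purely illustrative sketch, not JAX's actual implementation; every input value below is a stand-in:

```python
import hashlib

def toy_cache_key(hlo_text, jaxlib_version, xla_flags, device_config,
                  compression, custom_hook_str=""):
    """Combine the signature components into a single stable digest."""
    h = hashlib.sha256()
    for part in (hlo_text, jaxlib_version, *sorted(xla_flags),
                 device_config, compression, custom_hook_str):
        h.update(part.encode())
        h.update(b"\x00")  # separator so adjacent components cannot merge
    return h.hexdigest()

key = toy_cache_key(
    hlo_text="HloModule f ...",          # non-optimized HLO of the function
    jaxlib_version="0.4.26",
    xla_flags=["--xla_gpu_autotune_level=4"],
    device_config="8 devices, GPU name: H100",
    compression="zstd",
)
```

A change to any single component yields a different key, which is why, for example, upgrading jaxlib or changing an XLA flag invalidates previously cached entries.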
## Different Runtimes

Below we outline some observed behavior of the persistent compilation cache
in a variety of different runtimes, as it relates to which processes write to the cache.

### Single node with single process and single device

This is the simplest runtime, and the only process that compiles and writes to the compilation cache is the single process.
The number of devices does not matter to the cache key in this setup; only the type of device does.

### Single node with single process and multiple devices

Once again, the only process that compiles and writes to the compilation cache is the single process.
The difference between this setup and the previous one is that now the number of devices matters in addition to the type of device.

### Multiple processes and multiple devices (on either single or multiple nodes)

In this runtime, the first time a program is run (while the persistent cache is cold / empty) all processes will compile,
but only the process with rank 0 in the global communication group will write to the persistent cache.
In subsequent runs, all processes will attempt to read from the persistent cache,
so it is important for the persistent cache to be in a shared file system (e.g. NFS) or remote storage (e.g. GFS).
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
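A minimal sketch of the multi-process case follows. The shared mount path is a hypothetical example; the point is only that every process sets the same cache location, and rank 0 does the writing automatically:

```python
import jax

# In a real multi-node job, the processes would be brought up first, e.g.:
# jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)

# Every process points at the same shared path (assumed here to be an NFS
# mount visible to all nodes); rank 0 writes, all processes read.
jax.config.update("jax_compilation_cache_dir", "/shared/jax_cache")
```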
## Logging cache activity

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
Here are a few suggestions on how to begin.

Users can enable the logging of related source files by placing

```python
import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
```

at the top of the script.

### Examining cache misses

To examine and understand why there are cache misses, JAX includes a configuration flag that
enables the logging of all cache misses (including persistent compilation cache misses) with their explanations.
Although this is currently only implemented for tracing cache misses, the eventual goal is to
explain all cache misses. It can be enabled by setting the following configuration:

```python
import jax

jax.config.update("jax_explain_cache_misses", True)
```
## Pitfalls

There are a couple of pitfalls that have currently been discovered:

* The persistent cache currently doesn't work with functions that have host callbacks. In this situation, caching is completely avoided.
  - This is because the HLO contains a pointer to the callback, which changes from run to run even if the computation and compute infrastructure are exactly the same.

* The persistent cache currently doesn't work with a function that uses primitives that implement their own `custom_partitioning`.
  - The HLO of the function contains a pointer to the `custom_partitioning` callback, which leads to different cache keys for the same computation across runs.
  - In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
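To make the first pitfall concrete, here is a small sketch using `jax.debug.callback`; any host callback behaves the same way for caching purposes:

```python
import jax

def log_value(x):
    print("intermediate value:", x)

@jax.jit
def f(x):
    # The lowered HLO embeds a raw pointer to this host callback, and the
    # pointer differs between runs, so f is never persistently cached.
    jax.debug.callback(log_value, x)
    return x + 1

f(1.0)  # computes normally; only persistent caching is affected
```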
### Working around `custom_partitioning`

As mentioned, the compilation cache doesn't work with a function that is composed of primitives that implement `custom_partitioning`. However, it is possible to use `shard_map` to circumvent `custom_partitioning` for those primitives that do implement it and make the compilation cache work as expected.

Let's pretend we have a function `F` that implements a layernorm followed by a matrix multiplication, using a primitive `LayerNorm` that implements `custom_partitioning`:

```python
import jax

def F(x1, x2, gamma, beta):
    ln_out = LayerNorm(x1, gamma, beta)
    return ln_out @ x2
```
If we were to merely compile this function without `shard_map`, the cache key for `layernorm_matmul_without_shard_map` would be different every time we ran the same code:

```python
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_shardings=(...))(x1, x2, gamma, beta)
```

However, if we were to wrap the `LayerNorm` primitive in `shard_map` and define a function `G` that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time, despite `LayerNorm` implementing `custom_partitioning`:

```python
import jax

# G wraps the LayerNorm call in shard_map and otherwise performs the
# same computation as F.
ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
```
Note that the primitive that implements `custom_partitioning` must be wrapped in `shard_map` for this workaround. It is insufficient to wrap the outer function `F` in `shard_map`.