Jax persistent compilation cache user guide.

This user guide covers using the cache on local filesystems
and Google Cloud.

PiperOrigin-RevId: 623236335
This commit is contained in:
jax authors 2024-04-09 11:47:22 -07:00
parent 828e60cc03
commit afb775c168
2 changed files with 75 additions and 1 deletions

View File

@ -0,0 +1,74 @@
# Persistent Compilation Cache
JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.
## Usage
The compilation cache is enabled when the
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
```
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
```
See the sections below for more detail on `cache-location`.
[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
is an alternate way of setting `cache-location`.
### Local filesystem
`cache-location` can be a directory on the local filesystem. For example:
```
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```
Note: the cache does not have an eviction mechanism implemented. If the
cache-location is a directory in the local filesystem, its size will continue
to grow unless files are manually deleted.
### Google Cloud
When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:
* Create the bucket in the same region as where the workload will run.
* Create the bucket in the same project as the workloads VM(s). Ensure that
permissions are set so that the VM(s) can write to the bucket.
* There is no need for replication for smaller workloads. Larger workloads
could benefit from replication.
* Use “Standard” for the default storage class for the bucket.
* Set the soft delete policy to its shortest: 7 days.
* Set the object lifecycle to the expected duration of the workload run.
For example, if the workload is expected to run for 10 days, set the object
lifecycle to 10 days. That should cover restarts that occur during the entire
run. Use `age` for the lifecycle condition and `Delete` for the action. See
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
for details. If the object lifecycle is not set, the cache will continue to
grow since there is no eviction mechanism implemented.
* All encryption policies are supported.
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:
```
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```

View File

@ -15,7 +15,7 @@ or deployed codebases.
device_memory_profiling
debugging/index
gpu_performance_tips
persistent_compilation_cache
.. toctree::
:maxdepth: 1