jax authors fc1e1d4a65 Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
2024-06-21 14:50:49 -07:00

# Persistent Compilation Cache
<!--* freshness: { reviewed: '2024-04-09' } *-->
JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.
## Usage
The compilation cache is enabled when the
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
```python
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
```
See the sections below for more detail on `cache-location`.
[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
is an alternate way of setting `cache-location`.
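The snippet below is a minimal sketch of that alternative. It assumes the module path from the link above (`jax.experimental.compilation_cache.compilation_cache`, as of jax-v0.4.26). In that version `set_cache_dir()` appears to be a thin wrapper that updates the `jax_compilation_cache_dir` option, so the same "call it before the first compilation" rule applies.

```python
import jax
from jax.experimental.compilation_cache import compilation_cache as cc

# Assumed equivalent to jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache").
# Call this before JAX compiles anything.
cc.set_cache_dir("/tmp/jax-cache")
```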
### Local filesystem
`cache-location` can be a directory on the local filesystem. For example:
```python
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```
Note: the cache has no eviction mechanism. If `cache-location` is a directory
on the local filesystem, the cache will keep growing until files are deleted
manually.
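Because cleanup is manual, a small script can help. The sketch below uses a hypothetical helper, `prune_cache_dir`, which is not part of JAX. It deletes regular files whose modification time is older than a given age. It assumes that deleted entries are simply recompiled and rewritten the next time they are needed. It also assumes it runs while no JAX process is using the directory. Because it looks at modification time rather than last use, it may remove entries that are still being read.

```python
import time
from pathlib import Path


def prune_cache_dir(cache_dir: str, max_age_days: float) -> list[Path]:
    """Hypothetical helper (not a JAX API): delete regular files under
    cache_dir whose modification time is older than max_age_days.

    Returns the list of deleted paths.
    """
    cutoff = time.time() - max_age_days * 24 * 60 * 60
    removed = []
    # Materialize the listing first so deleting files does not disturb iteration.
    for path in list(Path(cache_dir).rglob("*")):
        # Only regular files are considered; directories are left in place.
        if path.is_file() and path.stat().st_mtime < cutoff:
            path.unlink()
            removed.append(path)
    return removed
```

For example, `prune_cache_dir("/tmp/jax-cache", max_age_days=7)` would remove entries written more than a week ago.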
### Google Cloud
When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:
* Create the bucket in the same region as where the workload will run.
* Create the bucket in the same project as the workload's VM(s), and make sure
the VM(s) have permission to write to the bucket.
* There is no need for replication for smaller workloads. Larger workloads
could benefit from replication.
* Use "Standard" as the bucket's default storage class.
* Set the soft delete policy to its shortest: 7 days.
* Set the object lifecycle to the expected duration of the workload run.
For example, if the workload is expected to run for 10 days, set the object
lifecycle to 10 days. That should cover restarts that occur during the entire
run. Use `age` for the lifecycle condition and `Delete` for the action. See
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
for details. If the object lifecycle is not set, the cache will continue to
grow since there is no eviction mechanism implemented.
* All encryption policies are supported.
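The lifecycle rule described above (an `age` condition with a `Delete` action) can be written as a JSON configuration file. The sketch below builds one with a hypothetical helper, `lifecycle_config`, which is not a JAX or Google Cloud API. It assumes the `{"rule": [...]}` file format accepted by `gsutil lifecycle set`.

```python
import json


def lifecycle_config(max_age_days: int) -> dict:
    """Hypothetical helper: a lifecycle config that deletes objects
    older than max_age_days."""
    return {
        "rule": [
            {
                "action": {"type": "Delete"},
                "condition": {"age": max_age_days},
            }
        ]
    }


# Example: a workload expected to run for 10 days.
print(json.dumps(lifecycle_config(10), indent=2))
```

Save the printed JSON to a file such as `lifecycle.json` and apply it with `gsutil lifecycle set lifecycle.json gs://jax-cache`. Here `gs://jax-cache` is the example bucket name used in this section.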
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:
```python
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```