diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index 37afa2f59..0a5a89abe 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
 If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
 in subsequent runs as a result of a compilation cache miss.
 
+### Pre-compiling multi-node programs on single node
+
+JAX can populate the compilation cache with compiled programs for multiple nodes
+on a single node. Preparing the cache on a single node helps to decrease the costly
+compilation time on a cluster. To compile and run multi-node programs on a single
+node, users can create fake remote devices using
+the `jax_mock_gpu_topology` configuration option.
+
+For instance, the snippet below instructs JAX to mock a cluster with four
+nodes, each node running eight processes with each process attached to one GPU.
+
+```python
+jax.config.update("jax_mock_gpu_topology", "4x8x1")
+```
+
+After populating the cache with this config, users can run the program
+without recompilation on four nodes, eight processes per node,
+one GPU per process.
+
+Important notes:
+
+* The process running the mocked program must have the same amount of GPUs
+  and the same GPU model as the nodes that would use the cache. For instance,
+  a mocked topology `8x4x2` must run in a process with two GPUs.
+
+* When running programs with mocked topology, the results of communications
+  with other nodes are undefined, so the outputs of JAX programs running
+  in mocked environments will likely be incorrect.
+
+
 ## Logging cache activity
 
 It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
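
Note on the topology string: the patch describes `"4x8x1"` as four nodes, eight processes per node, and one GPU per process, and says `8x4x2` must run in a process with two GPUs. The sketch below spells out that reading as arithmetic. It is only an illustration under that assumed `nodes x processes-per-node x GPUs-per-process` layout; `MockTopology` and `parse_mock_topology` are hypothetical names, not part of JAX, and the sketch does not call JAX at all.

```python
from typing import NamedTuple


class MockTopology(NamedTuple):
    """Hypothetical helper type (not a JAX API) for a mocked GPU topology."""

    nodes: int
    processes_per_node: int
    gpus_per_process: int

    @property
    def total_processes(self) -> int:
        # Number of processes the mocked cluster would contain.
        return self.nodes * self.processes_per_node

    @property
    def total_gpus(self) -> int:
        # Number of GPUs across the whole mocked cluster.
        return self.total_processes * self.gpus_per_process


def parse_mock_topology(spec: str) -> MockTopology:
    """Parse a string such as "4x8x1" into its three components.

    Assumption: the fields are, in order, nodes, processes per node,
    and GPUs per process, as the documentation text describes.
    """
    parts = spec.split("x")
    if len(parts) != 3:
        raise ValueError(f"expected three 'x'-separated fields, got {spec!r}")
    try:
        values = [int(part) for part in parts]
    except ValueError:
        raise ValueError(f"non-integer field in topology {spec!r}") from None
    if any(value <= 0 for value in values):
        raise ValueError(f"all fields must be positive in {spec!r}")
    return MockTopology(*values)
```

Under this reading, `parse_mock_topology("8x4x2").gpus_per_process` is the number of physical GPUs the pre-compiling process needs (two), while `total_gpus` is the size of the cluster being mocked. A check like this could run before calling `jax.config.update("jax_mock_gpu_topology", ...)` to catch a mismatch between the requested topology and the local machine early.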