Jaroslav Sevcik 2024-12-13 15:29:56 +00:00
parent a123d4e39e
commit cb2ab409f6

@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
### Pre-compiling multi-node programs on a single node
JAX can populate the compilation cache for a multi-node program while running
on a single node. Preparing the cache this way reduces the costly compilation
time on the cluster. To compile and run a multi-node program on a single node,
create fake remote devices with the `jax_mock_gpu_topology` configuration
option.
For instance, the snippet below instructs JAX to mock a cluster with four
nodes, each node running eight processes with each process attached to one GPU.
```python
jax.config.update("jax_mock_gpu_topology", "4x8x1")
```
After populating the cache with this config, users can run the program
without recompilation on four nodes, eight processes per node,
one GPU per process.
Important notes:
* The process running the mocked program must have the same number of GPUs
and the same GPU model as the nodes that will use the cache. For instance,
a mocked topology `8x4x2` must run in a process with two GPUs.
* When running programs with mocked topology, the results of communications
with other nodes are undefined, so the outputs of JAX programs running
in mocked environments will likely be incorrect.
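
The topology string reads as nodes x processes per node x GPUs per process, and the first note above ties the last component to the number of GPUs in the mocking process. The sketch below illustrates that reading. The names `MockTopology`, `parse_mock_topology`, and `check_local_gpus` are hypothetical helpers written for this example and are not part of JAX; the local GPU count is assumed to be supplied by the caller.

```python
from typing import NamedTuple


class MockTopology(NamedTuple):
    """Hypothetical container for a mocked topology (not a JAX type)."""
    nodes: int
    processes_per_node: int
    gpus_per_process: int

    @property
    def total_devices(self) -> int:
        # Total number of GPUs in the mocked cluster.
        return self.nodes * self.processes_per_node * self.gpus_per_process


def parse_mock_topology(spec: str) -> MockTopology:
    """Split a string such as "4x8x1" into its three positive integers."""
    parts = spec.split("x")
    if len(parts) != 3:
        raise ValueError(f"expected three 'x'-separated fields, got {spec!r}")
    values = [int(p) for p in parts]  # int() raises ValueError on bad input
    if any(v <= 0 for v in values):
        raise ValueError(f"all fields must be positive, got {spec!r}")
    return MockTopology(*values)


def check_local_gpus(spec: str, local_gpu_count: int) -> None:
    """Raise if the mocking process does not hold exactly the per-process GPUs."""
    topology = parse_mock_topology(spec)
    if local_gpu_count != topology.gpus_per_process:
        raise RuntimeError(
            f"topology {spec!r} needs {topology.gpus_per_process} GPU(s) "
            f"in the mocking process, found {local_gpu_count}"
        )


# "4x8x1": four nodes, eight processes per node, one GPU per process.
topology = parse_mock_topology("4x8x1")
print(topology, topology.total_devices)

# "8x4x2" must be mocked from a process that has two GPUs.
check_local_gpus("8x4x2", local_gpu_count=2)
```

A check like this could run before `jax.config.update("jax_mock_gpu_topology", ...)` so that a mismatched GPU count is caught before the cache is populated.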
## Logging cache activity
When debugging, it can be helpful to examine exactly what the persistent compilation cache is doing.