mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #18600 from nouiz:doc_compilation_cache
PiperOrigin-RevId: 584068904
This commit is contained in:
commit
ab9c973031
13
docs/jax.experimental.compilation_cache.rst
Normal file
13
docs/jax.experimental.compilation_cache.rst
Normal file
@ -0,0 +1,13 @@
|
||||
``jax.experimental.compilation_cache`` module
|
||||
===============================================
|
||||
|
||||
JAX disk compilation cache.
|
||||
|
||||
.. automodule:: jax.experimental.compilation_cache.compilation_cache
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autofunction:: is_initialized
|
||||
.. autofunction:: initialize_cache
|
||||
.. autofunction:: reset_cache
|
@ -23,6 +23,7 @@ Experimental Modules
|
||||
jax.experimental.jet
|
||||
jax.experimental.custom_partitioning
|
||||
jax.experimental.multihost_utils
|
||||
jax.experimental.compilation_cache
|
||||
|
||||
Experimental APIs
|
||||
-----------------
|
||||
|
@ -44,6 +44,9 @@ def initialize_cache(path):
|
||||
|
||||
Will throw an assertion error if called a second time with a different path.
|
||||
|
||||
Only works for GPU and TPU backend as the CPU backend don't
|
||||
implement yet the serialization API.
|
||||
|
||||
Args:
|
||||
path: path for the cache directory.
|
||||
"""
|
||||
@ -122,6 +125,8 @@ def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
|
||||
|
||||
|
||||
def is_initialized():
|
||||
"""Return True is there is a cache initialized.
|
||||
"""
|
||||
return _cache is not None
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user