mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] Add support for sharing an LRU cache between multiple C++ jit-ted functions.
Adds a new CompiledFunctionCache object that can be passed to the CompiledFunction constructor. Multiple CompiledFunctions can share the same cache capacity. This change is in preparation for adding jit decorators to many standard library functions. We do not want to drastically increase the number of cached computations, and having a cache shared between functions allows us to avoid this. Also allow cache entries to persist past the lifetime of the enclosing jax.jit(f) call so long as `f` remains alive. This mirrors the behavior of the existing linear_util cache that JAX uses in Python. PiperOrigin-RevId: 368664536
This commit is contained in:
parent
0d4bcde7ca
commit
8a221d3d4d
@ -356,6 +356,9 @@ class _BackendAndDeviceInfo(NamedTuple):
|
||||
committed_to_device: bool
|
||||
|
||||
|
||||
if lib._xla_extension_version >= 16:
|
||||
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
|
||||
|
||||
def _cpp_jit(
|
||||
fun: F,
|
||||
static_argnums: Union[int, Iterable[int], None] = None,
|
||||
@ -437,7 +440,6 @@ def _cpp_jit(
|
||||
execute.func is xla._execute_compiled and # not trivial, not pmap
|
||||
# Not supported: ShardedDeviceArray
|
||||
all(xla.type_is_device_array(x) for x in out_flat))
|
||||
|
||||
### If we can use the fastpath, we return required info to the caller.
|
||||
if use_fastpath:
|
||||
xla_executable, _, result_handlers = execute.args
|
||||
@ -469,11 +471,18 @@ def _cpp_jit(
|
||||
if lib._xla_extension_version < 14:
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums)
|
||||
f_jitted = wraps(fun)(cpp_jitted_f)
|
||||
else:
|
||||
elif lib._xla_extension_version < 16:
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
static_argnums=static_argnums,
|
||||
static_argnames=static_argnames)
|
||||
f_jitted = wraps(fun)(cpp_jitted_f)
|
||||
else:
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
static_argnums=static_argnums,
|
||||
static_argnames=static_argnames,
|
||||
donate_argnums=donate_argnums,
|
||||
cache=_cpp_jit_cache)
|
||||
f_jitted = wraps(fun)(cpp_jitted_f)
|
||||
|
||||
return f_jitted
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user