[JAX] Add support for sharing an LRU cache between multiple C++ jit-ted functions.

Adds a new CompiledFunctionCache object that can be passed to the CompiledFunction constructor. Multiple CompiledFunctions can share the same cache capacity.

This change is in preparation for adding jit decorators to many standard library functions. We do not want to drastically increase the number of cached computations, and having a cache shared between functions allows us to avoid this.

Also allow cache entries to persist past the lifetime of the enclosing jax.jit(f) call so long as `f` remains alive. This mirrors the behavior of the existing linear_util cache that JAX uses in Python.

PiperOrigin-RevId: 368664536
This commit is contained in:
Peter Hawkins 2021-04-15 10:15:51 -07:00 committed by jax authors
parent 0d4bcde7ca
commit 8a221d3d4d

View File

@ -356,6 +356,9 @@ class _BackendAndDeviceInfo(NamedTuple):
committed_to_device: bool
if lib._xla_extension_version >= 16:
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
def _cpp_jit(
fun: F,
static_argnums: Union[int, Iterable[int], None] = None,
@ -437,7 +440,6 @@ def _cpp_jit(
execute.func is xla._execute_compiled and # not trivial, not pmap
# Not supported: ShardedDeviceArray
all(xla.type_is_device_array(x) for x in out_flat))
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
xla_executable, _, result_handlers = execute.args
@ -469,11 +471,18 @@ def _cpp_jit(
if lib._xla_extension_version < 14:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums)
f_jitted = wraps(fun)(cpp_jitted_f)
else:
elif lib._xla_extension_version < 16:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames)
f_jitted = wraps(fun)(cpp_jitted_f)
else:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
cache=_cpp_jit_cache)
f_jitted = wraps(fun)(cpp_jitted_f)
return f_jitted