Reduce jax.jit dispatch overhead by avoiding directly comparing python objects

Previously, updating the thread-local state could leave it holding a new Python object, leading to expensive Python compare logic during compilation cache lookup. This CL adds a thread-local cache for the state, so that equal states are represented by the same object.

PiperOrigin-RevId: 456667829
This commit is contained in:
Kuangyuan Chen 2022-06-22 20:04:05 -07:00 committed by jax authors
parent 1908da33af
commit dc1c519547

View File

@ -507,13 +507,30 @@ class _ThreadLocalExtraJitContext(NamedTuple):
dynamic_shapes: bool = False
class _ThreadLocalStateCache(threading.local):
""""A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls.
We want to duduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object
IDs match.
"""
def __init__(self):
self.canonicalize = functools.lru_cache(128)(lambda x: x)
# Module-level _ThreadLocalStateCache instance; update_thread_local_jit_state
# passes each updated context through its `canonicalize` before storing it.
_thread_local_state_cache = _ThreadLocalStateCache()
def update_thread_local_jit_state(**kw):
  """Update fields of the extra jit context held in the thread-local state.

  The current context (or a default ``_ThreadLocalExtraJitContext`` if none is
  set) is copied with the given fields replaced. The copy is then passed
  through ``_thread_local_state_cache.canonicalize`` so that a context equal
  to one seen before is stored as that earlier object rather than as a fresh
  one.

  Args:
    **kw: field names and their new values, forwarded to ``_replace`` on the
      current context.
  """
  tls = jax_jit.thread_local_state()
  # After xla_client._version >= 70, the thread_local object will necessarily
  # be initialized when accessed. The following line can be removed when the
  # minimum jaxlib version is past version 70
  context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
  updated = context._replace(**kw)
  # Assign exactly once, and only the canonical object: storing the raw
  # `updated` copy first would be immediately overwritten.
  tls.extra_jit_context = _thread_local_state_cache.canonicalize(updated)
# TODO(mattjj): remove all uses of this flag