mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
Previously, the thread-local state could be updated to a new but equal object, leading to expensive Python-level comparison logic during compilation cache lookup. This CL adds a thread-local cache that canonicalizes the state so equal states share one object. PiperOrigin-RevId: 456667829
This commit is contained in:
parent
1908da33af
commit
dc1c519547
@ -507,13 +507,30 @@ class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
dynamic_shapes: bool = False
|
||||
|
||||
|
||||
class _ThreadLocalStateCache(threading.local):
  """A thread-local cache for _ThreadLocalExtraJitContext objects.

  The extra_jit_context in jax_jit.thread_local_state() may get updated, which
  incurs dispatch overhead for comparing this Python object during jit calls.
  We want to deduplicate the objects that have the same hash/equality so that
  they also have the same object ID, since the equality check is much faster
  if the object IDs match.

  Because this class derives from `threading.local`, `__init__` runs
  separately for each thread that uses an instance, so every thread gets its
  own `canonicalize` cache.
  """

  def __init__(self):
    # A memoized identity function: for any hashable argument it returns the
    # first-seen object that compares equal to it, as long as that entry is
    # still among the 128 most recently used ones.
    self.canonicalize = functools.lru_cache(maxsize=128)(lambda x: x)
|
||||
|
||||
|
||||
# Module-level instance of the canonicalization cache. _ThreadLocalStateCache
# subclasses threading.local, so the attributes of this object are per-thread.
_thread_local_state_cache = _ThreadLocalStateCache()
|
||||
|
||||
|
||||
def update_thread_local_jit_state(**kw):
  """Replace fields of the calling thread's extra JIT context.

  The updated context is passed through the thread-local canonicalization
  cache before being stored, so that contexts which compare equal are stored
  as the same object.

  Args:
    **kw: Field names of the context tuple mapped to their new values; they
      are forwarded unchanged to the context's `_replace` method.
  """
  tls = jax_jit.thread_local_state()
  # After xla_client._version >= 70, the thread_local object will necessarily
  # be initialized when accessed. The following line can be removed when the
  # minimum jaxlib version is past version 70
  context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
  # Store only the canonical object; assigning the raw `_replace` result first
  # would be a dead store that is overwritten immediately.
  updated = context._replace(**kw)
  tls.extra_jit_context = _thread_local_state_cache.canonicalize(updated)
|
||||
|
||||
|
||||
# TODO(mattjj): remove all uses of this flag
|
||||
|
Loading…
x
Reference in New Issue
Block a user