mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

When tracing inner jits, we currently redo a lot of tracing work that could be cached. We already have a C++ fast path for top-level jit calls, and we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and to split out the dynamic arguments, which together give a cache key. If we have seen the cache key before, we can skip most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
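The caching idea described above can be sketched in plain Python. This is not JAX's implementation: the real signature is computed in C++, and this message does not say exactly what the key contains. For illustration, the sketch assumes the key is made from the static arguments (hashed by value) plus the shape and dtype of each dynamic argument. `InnerJitCache`, `FakeArray`, `ShapedAval`, `to_shaped_aval`, and the `trace` callback are all hypothetical names; `trace` stands in for the expensive work a cache hit avoids.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Sequence, Tuple


@dataclass(frozen=True)
class ShapedAval:
    """Hypothetical abstract value: shape and dtype only, no concrete data."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass
class FakeArray:
    """Hypothetical array type; `data` is deliberately ignored by the key."""
    shape: Tuple[int, ...]
    dtype: str
    data: Any = None


def to_shaped_aval(x: FakeArray) -> ShapedAval:
    # Drop the concrete data so that arrays with the same shape and dtype
    # map to the same key; keeping the data would make every call a miss.
    return ShapedAval(tuple(x.shape), str(x.dtype))


class InnerJitCache:
    """Memoizes a hypothetical expensive tracing step, keyed on the signature."""

    def __init__(self, trace: Callable[[tuple, tuple], Any],
                 static_argnums: Sequence[int] = ()):
        self.trace = trace  # hypothetical: (static_args, avals) -> traced result
        self.static_argnums = frozenset(static_argnums)
        self._cache: Dict[Hashable, Any] = {}
        self.misses = 0

    def __call__(self, *args):
        # Static args are part of the key by value (they must be hashable).
        static = tuple((i, a) for i, a in enumerate(args)
                       if i in self.static_argnums)
        # Dynamic args contribute only their abstract signature to the key.
        dynamic = [a for i, a in enumerate(args)
                   if i not in self.static_argnums]
        avals = tuple(to_shaped_aval(a) for a in dynamic)
        key = (static, avals)
        if key not in self._cache:
            # Cache miss: do the expensive work once for this signature.
            self.misses += 1
            self._cache[key] = self.trace(static, avals)
        # Return the dynamic values too, since the caller still needs them.
        return self._cache[key], dynamic


# Usage: two calls with different data but the same signature share one trace.
cache = InnerJitCache(lambda static, avals: ("traced", static, avals),
                      static_argnums=(1,))
cache(FakeArray((3,), "float32", data=[1, 2, 3]), "add")
cache(FakeArray((3,), "float32", data=[4, 5, 6]), "add")
assert cache.misses == 1
```

A new shape, dtype, or static argument produces a new key and therefore a new trace. The sketch also shows why the abstraction step has to discard concrete values: if `to_shaped_aval` kept `data`, no two calls with different values would ever share a cache entry.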