When tracing inner jits, we currently redo a lot of tracing work that we could cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits: part of the C++ fast-path code computes the signature of the arguments and splits out the dynamic arguments, and together these form a cache key. If we have seen the cache key before, we can skip most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
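The caching idea, sketched at the Python level (this illustrates signature-keyed tracing caches in general, not the actual fast-path code; `traced_with_cache` is a hypothetical helper):
```
import jax
import jax.numpy as jnp

_trace_cache = {}

def traced_with_cache(fun, *args):
  # Key on everything tracing depends on: the function identity plus
  # the abstract signature (shape/dtype) of each argument.
  key = (fun, tuple((a.shape, a.dtype) for a in args))
  if key not in _trace_cache:
    # The expensive tracing step we want to skip on repeated calls.
    _trace_cache[key] = jax.make_jaxpr(fun)(*args)
  return _trace_cache[key]

def f(x, y):
  return jnp.sin(x) + y

x = jnp.ones((8,))
first = traced_with_cache(f, x, x)   # traces f
second = traced_with_cache(f, x, x)  # cache hit: same signature
assert first is second
```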
```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
In some situations, this also meant changing unrelated files to include tsl/platform/statusor.h directly to get the definitions of TF_ASSIGN_OR_RETURN, etc., which they had previously been picking up through transitive includes.
PiperOrigin-RevId: 645169743
Refactoring only, NFC intended.
* add types to more places.
* don't unpack PjitInfo positionally, since it's a 23-tuple and that seems rather error-prone.
* change _infer_params to produce a new PjitParams NamedTuple, rather than having callers unpack a 9-tuple positionally (see the sketch after this list).
* inline _pjit_jaxpr into its caller, since it only has one caller and the wrapper doesn't really clarify anything.
* note that the return type of transformation_with_aux is a Callable.
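A hedged illustration of the NamedTuple pattern (the fields below are made up for the example, not the actual PjitParams definition):
```
from typing import Any, NamedTuple

class PjitParams(NamedTuple):
  # Illustrative fields only.
  consts: list
  params: dict[str, Any]
  in_avals: tuple

def infer_params() -> PjitParams:
  return PjitParams(consts=[], params={"name": "f"}, in_avals=())

p = infer_params()
# Callers access results by name instead of unpacking a long tuple
# positionally, so adding or reordering fields can't silently
# misassign values.
print(p.params["name"])
```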
PiperOrigin-RevId: 645068326
The llvm.expect intrinsic puts the loop at the end of the program, allowing
the whole barrier to be compiled to a test_wait + predicated branch that is
immediately followed by the continuation. This seems to make the happy path
a little faster, which can help reduce the barrier overhead for compute-bound
kernels.
PiperOrigin-RevId: 645007019
Passing .c_str() to ParseFromString can lead to inconsistent behavior when the C string is not properly null-terminated. This diff initializes a std::string explicitly by providing the size of the buffer to be parsed.
PiperOrigin-RevId: 644979040
With this change we reach state-of-the-art performance (as far as I can tell)
of 50%+ TensorCore utilization for head_dim 128 and 256.
I also added a little tuning harness to try out different block sizes.
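The harness amounts to something like the following sketch (the `attention` entry point and its block-size keyword arguments are placeholders for the actual kernel's API):
```
import itertools
import time
import jax

def tune(attention, q, k, v,
         block_qs=(64, 128, 256), block_ks=(64, 128, 256)):
  """Time every block-size combination and return the fastest."""
  best = None
  for bq, bk in itertools.product(block_qs, block_ks):
    fn = jax.jit(lambda q, k, v, bq=bq, bk=bk:
                 attention(q, k, v, block_q=bq, block_k=bk))
    fn(q, k, v).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(10):
      out = fn(q, k, v)
    out.block_until_ready()
    elapsed = (time.perf_counter() - start) / 10
    if best is None or elapsed < best[0]:
      best = (elapsed, bq, bk)
  return best  # (seconds_per_call, block_q, block_k)
```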
PiperOrigin-RevId: 644927079
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 644845277
This flag is useful: users can increase the internal scratch size to get more efficient pl.roll ops and relayouts in Mosaic.
PiperOrigin-RevId: 644576369
We choose the best solution based on the size of the internal scratch memory:
- Sol 1: Convert the dynamic roll to log(N) static ops (see the sketch after this list)
- Sol 2: Static store + dynamic load with internal scratch
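Sol 1 boils down to the following idea, sketched here with jnp ops at the array level (Mosaic applies the same trick during lowering; this is an illustration, not the Mosaic code):
```
import jax.numpy as jnp

def dynamic_roll(x, shift):
  """Roll `x` by a traced `shift` using only log2(n) static rolls.

  For each power-of-two step below n, roll by that step with a static
  amount (cheap), then keep the rolled or original value depending on
  whether the corresponding bit of `shift` is set.
  """
  n = x.shape[0]
  shift = shift % n
  step = 1
  while step < n:
    rolled = jnp.roll(x, step, axis=0)  # static roll amount
    x = jnp.where((shift & step) != 0, rolled, x)
    step *= 2
  return x
```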
PiperOrigin-RevId: 644509328