vector.contract has been working fine so far, but it's starting to limit us.
The biggest issue is that it does not allow lhs and rhs to have different
data types, which can be supported more efficiently than when the casts are
separate operations.
PiperOrigin-RevId: 562792247
This is often useful when a kernel uses statistics tensors that are constant across
the minormost dimensions. Right now the only way to use them is to force XLA to
insert the extra dimension before the kernel, but that turns out to be very inefficient.
PiperOrigin-RevId: 561903222
This allows us to override the inferred tiling of the values, which makes it possible to
e.g. preswizzle the data into a more efficient format before the kernel.
PiperOrigin-RevId: 553402946
This lets us only perform an all-reduce once at the end of a reduction, instead of
at every step. This also bundles two small improvements, making layout inference
less strict for `vector.broadcast` and relaxing an assert in elementwise rule.
PiperOrigin-RevId: 552413179
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.
We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.
PiperOrigin-RevId: 551372130
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().