This allows us to override the inferred tiling of the values, which makes it possible to
e.g. preswizzle the data into a more efficient format before the kernel.
PiperOrigin-RevId: 553402946
This lets us only perform an all-reduce once at the end of a reduction, instead of
at every step. This also bundles two small improvements, making layout inference
less strict for `vector.broadcast` and relaxing an assert in elementwise rule.
PiperOrigin-RevId: 552413179
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.
We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.
PiperOrigin-RevId: 551372130
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().