13 Commits

Author SHA1 Message Date
Jevin Jiang
5d8f5c20fa [Mosaic] Remove multiple results check in apply layout.
PiperOrigin-RevId: 555320679
2023-08-09 17:17:25 -07:00
Adam Paszke
16c33df3cf [Mosaic] Relax return type checks for vector.contract
PiperOrigin-RevId: 553898552
2023-08-04 13:34:12 -07:00
Jevin Jiang
df4e0cbb73 [Mosaic] Add support for relayout from (8, 128) to (1, 128) with 32-bit data.
PiperOrigin-RevId: 553606429
2023-08-03 15:24:15 -07:00
Adam Paszke
a184b5e4af [Mosaic] Add a reinterpret cast for memrefs
This allows us to override the inferred tiling of the values, which makes it possible to
e.g. preswizzle the data into a more efficient format before the kernel.

PiperOrigin-RevId: 553402946
2023-08-03 01:50:19 -07:00
Adam Paszke
716f4f8119 [Mosaic] Allow users to opt out of window prefetching
PiperOrigin-RevId: 552797922
2023-08-01 07:36:59 -07:00
Jevin Jiang
6e37c4202d [Mosaic] Add support for sublane strided store.
PiperOrigin-RevId: 552595663
2023-07-31 14:40:27 -07:00
Jevin Jiang
9d62d867bc [Mosaic] Add support for sublane strided load.
PiperOrigin-RevId: 552581319
2023-07-31 13:47:43 -07:00
Adam Paszke
5c952a1c04 [Mosaic] Swap the order of all-reduce and elementwise ops
This lets us only perform an all-reduce once at the end of a reduction, instead of
at every step. This also bundles two small improvements, making layout inference
less strict for `vector.broadcast` and relaxing an assert in elementwise rule.

PiperOrigin-RevId: 552413179
2023-07-31 02:01:57 -07:00
Skye Wanderman-Milne
a03d6e6613 Move _tpu_ext.cc to jaxlib/mlir/_mlir_libs and set RPATH correctly
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.

We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.

PiperOrigin-RevId: 551372130
2023-07-26 18:25:17 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
jax authors
32cbc3678d Integrate LLVM at llvm/llvm-project@571c1292b6
Updates LLVM usage to match
[571c1292b693](https://github.com/llvm/llvm-project/commit/571c1292b693)

PiperOrigin-RevId: 550071080
2023-07-21 15:56:28 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00