483 Commits

Author SHA1 Message Date
Jevin Jiang
d4b564a263 [Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0).
PiperOrigin-RevId: 564882618
2023-09-12 17:35:09 -07:00
Benjamin Kramer
a26125c49e Integrate LLVM at llvm/llvm-project@c1796be93f
Updates LLVM usage to match
[c1796be93fe5](https://github.com/llvm/llvm-project/commit/c1796be93fe5)

PiperOrigin-RevId: 564842806
2023-09-12 15:00:09 -07:00
Jevin Jiang
801cbef011 [Mosaic] Use strided load to load one entire row more efficiently.
PiperOrigin-RevId: 564831610
2023-09-12 14:19:35 -07:00
Jevin Jiang
9d8642122a [Mosaic] Use strided store to store one row.
PiperOrigin-RevId: 564821813
2023-09-12 13:56:58 -07:00
Adam Paszke
dbb0e8f214 [Mosaic] Add a pass for instantiating memory spaces
PiperOrigin-RevId: 564723473
2023-09-12 08:05:26 -07:00
Sharad Vikram
d0c4c9b3fe [Pallas] Add support for scoped allocations to Pallas TPU
PiperOrigin-RevId: 563580548
2023-09-07 16:41:01 -07:00
jax authors
bfd79b84e4 Support relayout of tiles in register when the layout tiling changes.
PiperOrigin-RevId: 563570338
2023-09-07 16:00:44 -07:00
Tomás Longeri
a6eed40f24 [MOSAIC] apply_vector_layout C++ rewrite (3) applyLayoutOp and relayout
PiperOrigin-RevId: 563556815
2023-09-07 15:08:30 -07:00
Jevin Jiang
8b700fa75d [Mosaic] Support relayout from (1, 128) to (8, 128).
PiperOrigin-RevId: 563534657
2023-09-07 13:49:44 -07:00
Adam Paszke
1ea5b26524 [Mosaic][Pallas] Extend support for slicing and concatenation
... and squeezing. Those are useful when implementing batched matrix multiplication kernels.

PiperOrigin-RevId: 563381169
2023-09-07 03:45:01 -07:00
Adam Paszke
2f92517349 [Mosaic] Add a tpu.matmul op
vector.contract has been working fine so far, but it's starting to limit us.
The biggest issue is that it does not allow lhs and rhs to have different
data types, which can be supported more efficiently than when the casts are
separate operations.

PiperOrigin-RevId: 562792247
2023-09-05 08:51:41 -07:00
Adam Paszke
73275aa9ef [Mosaic] Add support for inserting a new lane dimension
This is often useful when a kernel uses statistics tensors that are constant across
the minormost dimensions. Right now the only way to use them is to force XLA to
insert the extra dimension before the kernel, but that turns out to be very inefficient.

PiperOrigin-RevId: 561903222
2023-09-01 03:02:02 -07:00
Tomás Longeri
24c3a9dc79 [MOSAIC] apply_vector_layout C++ rewrite (2) No-op pass and flag to use it instead of Python
PiperOrigin-RevId: 561697585
2023-08-31 10:42:03 -07:00
Tomás Longeri
d02b59e410 [MOSAIC] apply_vector_layout C++ rewrite (1) VectorLayout functions
PiperOrigin-RevId: 561237760
2023-08-29 22:57:13 -07:00
Jevin Jiang
046bcc0ad9 [Mosaic] Add missing headers in linalg vectorization.
PiperOrigin-RevId: 560865251
2023-08-28 17:39:47 -07:00
Tomás Longeri
c3e624ad8a [Mosaic] Fix assert
PiperOrigin-RevId: 560756391
2023-08-28 10:50:23 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Christian Sigg
b4f7928a81 [NFC] Explicitly set dialects to usePropertiesForAttributes=0 in preparation for https://reviews.llvm.org/D158581 (flipping the default to 1) to land.
This allows us to switch dialects to use properties one by one.

PiperOrigin-RevId: 559751065
2023-08-24 07:52:19 -07:00
Adam Paszke
45428c1375 [Mosaic] Add an annotation explaining signedness semantics of tpu.pack_elements
PiperOrigin-RevId: 559747952
2023-08-24 07:32:59 -07:00
Sharad Vikram
f08df0f0c3 [Pallas] Add Mosaic lowering for pow2
PiperOrigin-RevId: 559291137
2023-08-22 19:42:17 -07:00
Jevin Jiang
5d8f5c20fa [Mosaic] Remove multiple results check in apply layout.
PiperOrigin-RevId: 555320679
2023-08-09 17:17:25 -07:00
Adam Paszke
16c33df3cf [Mosaic] Relax return type checks for vector.contract
PiperOrigin-RevId: 553898552
2023-08-04 13:34:12 -07:00
Jevin Jiang
df4e0cbb73 [Mosaic] Add support for relayout from (8, 128) to (1, 128) with 32-bit data.
PiperOrigin-RevId: 553606429
2023-08-03 15:24:15 -07:00
Adam Paszke
a184b5e4af [Mosaic] Add a reinterpret cast for memrefs
This allows us to override the inferred tiling of the values, which makes it possible to
e.g. preswizzle the data into a more efficient format before the kernel.

PiperOrigin-RevId: 553402946
2023-08-03 01:50:19 -07:00
Adam Paszke
716f4f8119 [Mosaic] Allow users to opt out of window prefetching
PiperOrigin-RevId: 552797922
2023-08-01 07:36:59 -07:00
Jevin Jiang
6e37c4202d [Mosaic] Add support for sublane strided store.
PiperOrigin-RevId: 552595663
2023-07-31 14:40:27 -07:00
Jevin Jiang
9d62d867bc [Mosaic] Add support for sublane strided load.
PiperOrigin-RevId: 552581319
2023-07-31 13:47:43 -07:00
Adam Paszke
5c952a1c04 [Mosaic] Swap the order of all-reduce and elementwise ops
This lets us only perform an all-reduce once at the end of a reduction, instead of
at every step. This also bundles two small improvements, making layout inference
less strict for `vector.broadcast` and relaxing an assert in elementwise rule.

PiperOrigin-RevId: 552413179
2023-07-31 02:01:57 -07:00
Skye Wanderman-Milne
a03d6e6613 Move _tpu_ext.cc to jaxlib/mlir/_mlir_libs and set RPATH correctly
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.

We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.

PiperOrigin-RevId: 551372130
2023-07-26 18:25:17 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
jax authors
32cbc3678d Integrate LLVM at llvm/llvm-project@571c1292b6
Updates LLVM usage to match
[571c1292b693](https://github.com/llvm/llvm-project/commit/571c1292b693)

PiperOrigin-RevId: 550071080
2023-07-21 15:56:28 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00