We can use the relayout-insertion pass to insert the ops and layouts needed for relayout before unrolling in the apply-vector-layout pass.
PiperOrigin-RevId: 708143852
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing an error, we should just reset the offset to zero. Otherwise, ops that have relaxed this restriction might be passed as inputs to un-relaxed ops and cause failures.
PiperOrigin-RevId: 680706301
This fixes the failure in the elementwise rule of the apply vector layout pass.
If the condition scalar is static, MLIR will simplify the select to the corresponding vector (the true value or the false value).
If the condition scalar is dynamic, we want to use vselect over scf.if anyway, because the latter creates an inner region.
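As a rough JAX-level analogy only (not the Mosaic lowering itself), the sketch below contrasts a select with a scalar predicate against the same choice written as control flow. The function names are hypothetical; `lax.select` and `lax.cond` are standard JAX calls.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pick_with_select(pred, on_true, on_false):
    # Scalar predicate, vector operands: a single select, no nested region.
    return lax.select(pred, on_true, on_false)


def pick_with_cond(pred, on_true, on_false):
    # Same result, but expressed as control flow with two branch bodies.
    return lax.cond(pred, lambda: on_true, lambda: on_false)


x = jnp.arange(8, dtype=jnp.float32)
y = -x
flag = jnp.asarray(True)  # dynamic under jit: traced, not a Python constant

selected = jax.jit(pick_with_select)(flag, x, y)
branched = jax.jit(pick_with_cond)(flag, x, y)
selected_false = jax.jit(pick_with_select)(jnp.asarray(False), x, y)
```

Both forms return the same values; the select form simply avoids introducing branch bodies.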
PiperOrigin-RevId: 680674560
This CL refactors Pallas memref indexers into transforms, which can support different ref transforms: indexing, bitcast (added in this CL), reshape (to be added), and others. As with indexers, users can apply multiple transforms to the same memref, e.g.:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```
Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
b[:,:] <- c
in () }
```
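To make the shapes in the preview easier to follow, here is a minimal pure-Python sketch that tracks only dtype and shape through the same chain of transforms. It assumes, based solely on the shapes shown in the jaxpr above, that a bitcast keeps the column count and rescales the row count by the ratio of bit widths. `RefView`, `bitcast`, and `at` are hypothetical helpers for illustration, not the Pallas API.

```python
from dataclasses import dataclass

# Bit widths for the dtypes that appear in the jaxpr preview.
BITS = {"int32": 32, "int16": 16, "float16": 16}


@dataclass(frozen=True)
class RefView:
    dtype: str
    shape: tuple  # (rows, cols)


def bitcast(view, dtype):
    # Assumed rule (inferred from the preview): columns are unchanged and
    # rows scale by old_bits / new_bits.
    rows, cols = view.shape
    total_bits = rows * BITS[view.dtype]
    if total_bits % BITS[dtype]:
        raise ValueError("row count does not divide evenly for this bitcast")
    return RefView(dtype, (total_bits // BITS[dtype], cols))


def at(view, row_slice, col_slice):
    # Slicing changes only the shape, never the dtype.
    rows = len(range(*row_slice.indices(view.shape[0])))
    cols = len(range(*col_slice.indices(view.shape[1])))
    return RefView(view.dtype, (rows, cols))


a = RefView("int32", (16, 256))
v = at(a, slice(None, 8), slice(None))    # a[:8,:]
v = bitcast(v, "int16")                   # bitcast(int16[16,256])
v = bitcast(v, "float16")                 # bitcast(float16[16,256])
v = at(v, slice(None), slice(None, 128))  # [:,:128]
v = bitcast(v, "int32")                   # bitcast(int32[8,128])
result = at(v, slice(None), slice(None))  # [:,:]
```

Under that assumption, the final view has the same dtype and shape as `b` in the preview.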
Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
Over time, `jax_triton/ops_test.py` has accumulated many tests that are in fact platform-independent.
Furthermore, those tests were Google-internal only, even though they can be external as well.
This moves test coverage for Pallas from the jax_triton package to the Pallas core package.
A small number of the tests were deleted because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py::ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py::OpsTest`.
The other tests were distributed to different files in the Pallas repo, according to their purpose:
* tests in `jax_triton/ops_test.py::PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest`
* tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there.
* some tests in `jax_triton/ops_test.py::OpsTest` are moved to `ops_test.py::OpsTest`.
* some tests for TPU-specific ops in `jax_triton/ops_test.py::OpsTest` are moved to a new test file `tpu_ops_test.py`
Some of this required adding sharding and hypothesis support to `ops_test.py`,
and adding TPU versions of `indexing_test.py`.
PiperOrigin-RevId: 662045774
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.
PiperOrigin-RevId: 656347969
Previously, I moved the platform-specific tests into their own `tpu` and `gpu` subdirectories, with the multi-platform tests at the top level in
the `tests/pallas` directory.
It turns out that `pytest` wants every test base file name to be unique when it is loading tests, and in order to be able to run `pytest tests/pallas` I sometimes had to add platform names to the test file name, even though it was already in a platform-specific directory, e.g., `gpu/gpu_ops_test.py`.
Here we delete the `tpu` and `gpu` test subdirectories and we prepend the platform name to the test file name.
Additionally, the old `tpu/pallas_call_test.py` is now renamed `tpu_pallas_test.py` (similar to the multi-platform test `pallas_test.py`).
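The renaming rule described above can be sketched as follows. This is an illustration under assumptions, not the script used for the change: `flattened_name` and `duplicate_basenames` are hypothetical helpers, and the file list is made up apart from `gpu/gpu_ops_test.py`. The one-off rename of `tpu/pallas_call_test.py` to `tpu_pallas_test.py` is not covered by this generic rule.

```python
from collections import Counter
from pathlib import PurePosixPath


def flattened_name(path, platforms=("tpu", "gpu")):
    """Return the top-level file name for a test under a platform directory."""
    p = PurePosixPath(path)
    platform = p.parent.name
    if platform not in platforms:
        return p.name  # already a top-level, multi-platform test
    prefix = platform + "_"
    # Do not double the prefix for files such as gpu/gpu_ops_test.py.
    return p.name if p.name.startswith(prefix) else prefix + p.name


def duplicate_basenames(paths):
    """Base names that occur more than once in the given paths."""
    counts = Counter(PurePosixPath(p).name for p in paths)
    return sorted(name for name, n in counts.items() if n > 1)


old_layout = ["ops_test.py", "tpu/ops_test.py", "gpu/gpu_ops_test.py"]
new_layout = [flattened_name(p) for p in old_layout]
```

With the old layout, `ops_test.py` appears twice as a base name; after flattening, every base name is unique.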
PiperOrigin-RevId: 651029357