30 Commits

Author SHA1 Message Date
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
The change is also applied to JAX, because the index does not need to be normalized if the mode is already "promise_in_bounds".
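
As a rough illustration of the mode this commit refers to, the sketch below calls `jnp.take_along_axis` from ordinary JAX (outside any Pallas kernel) with and without `mode="promise_in_bounds"`. The array contents are made up for the example; with indices that really are in bounds, both calls should agree.

```python
import jax.numpy as jnp

# Illustrative data only: a 3x4 int32 array and in-bounds column indices.
x = jnp.arange(12, dtype=jnp.int32).reshape(3, 4)
idx = jnp.array(
    [[3, 2, 1, 0],
     [0, 0, 1, 1],
     [2, 3, 2, 3]],
    dtype=jnp.int32,
)

# Default mode: JAX applies its usual out-of-bounds handling to the indices.
default = jnp.take_along_axis(x, idx, axis=1)

# promise_in_bounds: the caller guarantees 0 <= idx < x.shape[axis],
# so no index normalization is needed. Out-of-bounds or negative
# indices are the caller's responsibility in this mode.
promised = jnp.take_along_axis(x, idx, axis=1, mode="promise_in_bounds")
```

The promise is not checked, so this mode is only appropriate when the indices are known to be valid by construction.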

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Jevin Jiang
908df65a26 [Mosaic TPU] Emulate converting x16 vector to mask if mask packing is supported.
PiperOrigin-RevId: 718133639
2025-01-21 17:23:33 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Adam Paszke
74cf67df9d [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
2025-01-10 04:35:38 -08:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Jevin Jiang
2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use the relayout-insertion pass to insert the ops needed for a relayout, together with their layouts, before unrolling in the apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Adam Paszke
ad00ec1dc9 [Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Adam Paszke
45159494e5 [Pallas:TPU] Use self.pallas_call to properly handle interpret mode
PiperOrigin-RevId: 707865950
2024-12-19 04:22:48 -08:00
Jevin Jiang
3a5c4da4ef [Mosaic TPU] Support i32 vector multi reduction except cross lane.
PiperOrigin-RevId: 707708236
2024-12-18 16:49:07 -08:00
Jevin Jiang
74eca1346d [Pallas] Add version guard for non-32-bit selection in test and fix github build failure.
PiperOrigin-RevId: 707645847
2024-12-18 13:11:10 -08:00
Jevin Jiang
bf692efbfb [Mosaic TPU] Support direct cast i8 vector to mask
PiperOrigin-RevId: 707617318
2024-12-18 11:35:14 -08:00
Jevin Jiang
0fe77bc9f0 [Mosaic TPU] Support relayout for mask vector
We cast the i1 vector (mask) to an i32 vector before the relayout and then cast it back to an i1 vector (mask) after the relayout is finished.

PiperOrigin-RevId: 697823543
2024-11-18 18:07:15 -08:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Jevin Jiang
4a596aee1e [Mosaic TPU] Force offset to 0 when inferring input has offset out of the first tile.
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing an error, we should just reset the offset to zero, because some ops that have relaxed this restriction might be passed as inputs to un-relaxed ops and cause failures.

PiperOrigin-RevId: 680706301
2024-09-30 13:52:48 -07:00
Jevin Jiang
7e2f487ada [Mosaic TPU] Canonicalize arith.select's condition to vector if other types are vector.
This fixes the failure in the elementwise rule of the apply vector layout pass.

If the condition scalar is static, MLIR will simplify the select to the corresponding vector, taken from the true value or the false value.

If the condition scalar is dynamic, we want to use vselect over scf.if anyway, because the latter creates an inner region.
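
A user-level analogue of the case described above, sketched in plain JAX: a scalar boolean condition selecting between two vector operands. This is not the Mosaic canonicalization itself, and the assumption that such a `jnp.where` reaches the compiler as an `arith.select` with a scalar condition inside a Pallas kernel is mine, not stated in the commit. Function and variable names are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pick(flag, a, b):
    # `flag` is a traced (dynamic) scalar bool. jnp.where broadcasts it
    # against the vector operands, so the selection stays elementwise
    # instead of becoming a branch.
    return jnp.where(flag, a, b)


a = jnp.ones((8, 128), jnp.float32)
b = jnp.zeros((8, 128), jnp.float32)

out_true = pick(jnp.array(True), a, b)
out_false = pick(jnp.array(False), a, b)
```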

PiperOrigin-RevId: 680674560
2024-09-30 12:26:44 -07:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This CL refactors Pallas memref indexers into transforms, which can support different ref transforms: indexing, bitcast (added in this CL), reshape (to be added), and others. As with indexers, users can apply multiple transforms to the same memref, e.g.:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```
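
The shapes in the preview can be followed with a small shape-only sketch. The helper `bitcast_shape` is hypothetical and not part of Pallas; it encodes a rule inferred solely from the preview above, namely that a bitcast keeps the minor dimension and rescales the second-minor dimension by the ratio of element bit widths.

```python
def bitcast_shape(shape, old_bits, new_bits):
    """Hypothetical helper: shape of a 2D ref after a bitcast.

    Assumes (from the jaxpr preview only) that the last dimension is kept
    and the second-to-last dimension absorbs the bit-width change.
    """
    rows, cols = shape
    total_bits = rows * old_bits
    if total_bits % new_bits != 0:
        raise ValueError("row count is not divisible for this bit-width change")
    return (total_bits // new_bits, cols)


# Follow the chain of transforms applied to `a` in the preview.
s0 = (8, 256)                      # a[:8, :] on int32[16, 256]
s1 = bitcast_shape(s0, 32, 16)     # bitcast to int16
s2 = bitcast_shape(s1, 16, 16)     # bitcast to float16 (same width)
s3 = (s2[0], 128)                  # [:, :128]
s4 = bitcast_shape(s3, 16, 32)     # bitcast to int32
```

The final shape matches the `i32[8,128]` value stored into `b`.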

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
jax authors
a2a351f88b Fix pallas int4->int8 conversion
PiperOrigin-RevId: 666939965
2024-08-23 15:19:02 -07:00
Justin Fu
07767e81a0 [Pallas] Add support for casting to/from unsigned integer types.
PiperOrigin-RevId: 666663406
2024-08-22 23:57:17 -07:00
Justin Fu
21d57030dc [Pallas] Skip TPU-specific tests on win32
PiperOrigin-RevId: 665543351
2024-08-20 15:06:13 -07:00
Ayaka
36739e84ce Normalize "interpreter mode" to "interpret mode", and "InterpreterTest" to "InterpretTest"
This is because both "interpret mode" and "interpreter mode" occur in code, and "interpret mode" is more frequent.

PiperOrigin-RevId: 664873359
2024-08-19 10:40:22 -07:00
George Necula
7f680aaab8 [pallas] Move ops_test.py from jax_triton to jax/pallas
The `jax_triton/ops_test.py` has over time accumulated many tests that are in fact platform-independent tests.
Furthermore, those tests were only Google-internal, and they can be external as well.

This moves test coverage for Pallas from the jax_triton package to the Pallas core package.

A small number of the tests were deleted, because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py:ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py:OpsTest`.

The other tests were distributed to different files in the Pallas repo, according to their purpose:

  * tests in `jax_triton/ops_test.py:PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest`
  * tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there.
  * some tests in `jax_triton/ops_test.py:OpsTest` are moved to `ops_test.py::OpsTest`.
  * some tests for TPU-specific ops in `jax_triton/ops_test.py:OpsTest` are moved to a new test file `tpu_ops_test.py`

Some of this required adding sharding and hypothesis support to `ops_test.py`,
and adding TPU versions of `indexing_test.py`.

PiperOrigin-RevId: 662045774
2024-08-12 05:09:37 -07:00
Ayaka
bb160cf54e Move TPU ops test to ops_test.py
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.

PiperOrigin-RevId: 656347969
2024-07-26 04:24:13 -07:00
Ayaka
6cc09173d5 Add lowering for lax.sign
2024-07-26 10:33:42 +08:00
George Necula
6f79925d61 [pallas] Renamed platform-specific tests.
Previously I moved the platform-specific tests into their own `tpu` and `gpu` subdirectories, with the multi-platform tests at the top level in the `tests/pallas` directory.

It turns out that `pytest` wants every test base file name to be unique when it is loading tests, and in order to be able to run `pytest tests/pallas` I sometimes had to add platform names to the test file name, even though it was already in a platform-specific directory, e.g., `gpu/gpu_ops_test.py`.

Here we delete the `tpu` and `gpu` test subdirectories and we prepend the platform name to the test file name.

Additionally, the old `tpu/pallas_call_test.py` is now renamed `tpu_pallas_test.py` (similar to the multi-platform test `pallas_test.py`).

PiperOrigin-RevId: 651029357
2024-07-10 08:23:06 -07:00