Justin Fu
d129438548
[Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU.
...
PiperOrigin-RevId: 708010809
2024-12-19 13:28:58 -08:00
Adam Paszke
23000a3842
Always suppress the differing_executors Hypothesis health check
...
The check is only relevant for flagging potential key collisions in the example
database, but we explicitly disable that database, so it doesn't matter.
PiperOrigin-RevId: 707914664
2024-12-19 08:00:53 -08:00
Adam Paszke
ad00ec1dc9
[Mosaic TPU] Guard tests for new features by the libtpu version
...
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Adam Paszke
45159494e5
[Pallas:TPU] Use self.pallas_call to properly handle interpret mode
...
PiperOrigin-RevId: 707865950
2024-12-19 04:22:48 -08:00
Ayaka
613a0cde0a
[Pallas] Fix lowering tests for reduction ops
...
Remove unnecessary skip statements. Also added tests for bf16 types.
PiperOrigin-RevId: 707739536
2024-12-18 18:43:15 -08:00
Jevin Jiang
3a5c4da4ef
[Mosaic TPU] Support i32 vector multi reduction except cross lane.
...
PiperOrigin-RevId: 707708236
2024-12-18 16:49:07 -08:00
Jevin Jiang
74eca1346d
[Pallas] Add version guard for non-32-bit selection in test and fix github build failure.
...
PiperOrigin-RevId: 707645847
2024-12-18 13:11:10 -08:00
Jevin Jiang
bf692efbfb
[Mosaic TPU] Support direct cast i8 vector to mask
...
PiperOrigin-RevId: 707617318
2024-12-18 11:35:14 -08:00
Adam Paszke
3c52ca6e1c
Use the right health check suppression
...
The last change actually failed to fix the Cloud TPU tests.
This time I got a Cloud TPU for myself and verified it works.
PiperOrigin-RevId: 707534853
2024-12-18 06:55:12 -08:00
Sergei Lebedev
5e9cfce296
[pallas:mosaic_gpu] Added a lowering rule for the general lax.while_loop_p
...
Previously, our lowering only handled while loops which can be rewritten as
for loops.
PiperOrigin-RevId: 707533279
2024-12-18 06:49:32 -08:00
Adam Paszke
d95b95b405
[Mosaic TPU] Add support for exp, exp2 and log in bf16 on TPUv6
...
PiperOrigin-RevId: 707520511
2024-12-18 05:59:12 -08:00
Justin Fu
7e96914e61
Add Pallas Philox implementation.
...
Implemented in the same style as the threefry kernel. Philox is roughly 2x faster than the existing JAX Threefry implementation in both runtime and compile time.
PiperOrigin-RevId: 707276043
2024-12-17 15:37:08 -08:00
Adam Paszke
e1f037bffd
Reduce the max_examples for splash attention tests
...
They are too expensive to run all the time. We have 25 test methods, so it's ok
to drop the number of examples each one of them covers.
PiperOrigin-RevId: 707124466
2024-12-17 09:06:19 -08:00
Adam Paszke
3fc237125b
Make gmm TPU kernel tests significantly cheaper
...
We were testing lots of very similar cases that did not really help a lot with coverage.
PiperOrigin-RevId: 707115030
2024-12-17 08:38:59 -08:00
Adam Paszke
0ec902df83
[Mosaic TPU] Add support for bf16 abs
...
PiperOrigin-RevId: 707041113
2024-12-17 04:31:56 -08:00
Adam Paszke
16c44e51ac
Ignore a false-positive Hypothesis health check
...
See https://github.com/HypothesisWorks/hypothesis/issues/3733 for details,
but in short Hypothesis does not like parameterized tests under pytest.
PiperOrigin-RevId: 706677075
2024-12-16 05:48:46 -08:00
Gleb Pobudzey
e92ca9bbae
Use boolean values for partial mask blocks in the splash attention kernel.
...
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.
PiperOrigin-RevId: 705252534
2024-12-11 14:59:30 -08:00
Gleb Pobudzey
20236f1083
Increase shard count after adding more tests
...
PiperOrigin-RevId: 705146601
2024-12-11 10:08:50 -08:00
Ayaka
13ce51785d
[Pallas] Remove `grid=1` in tests
...
Remove `grid=1` in tests because it's the same as not specifying `grid`.
PiperOrigin-RevId: 705077047
2024-12-11 05:56:32 -08:00
Ayaka
e88b578356
[Pallas TPU] Add `WeirdOp` to TPU dialect and add lowering for `lax.is_finite`
...
PiperOrigin-RevId: 704888940
2024-12-10 16:38:04 -08:00
jax authors
3ca9f14107
Merge pull request #25361 from Rifur13:regression
...
PiperOrigin-RevId: 704885039
2024-12-10 16:25:36 -08:00
Jacob Burnim
1c1a17e0f0
Only run tpu_all_gather_test on tpu_v5e_4x2
...
PiperOrigin-RevId: 704871583
2024-12-10 15:42:42 -08:00
Tzu-Wei Sung
e418e88321
[Pallas] Add non-square pl.dot test cases.
...
PiperOrigin-RevId: 704788500
2024-12-10 11:38:28 -08:00
Dan Foreman-Mackey
978d35f697
Fix expected exception type in pallas grad tests.
...
PiperOrigin-RevId: 704408603
2024-12-09 14:02:07 -08:00
Gleb Pobudzey
e1e174fbc4
Adding more tests for multi-head attention
2024-12-09 20:49:06 +00:00
Ayaka
9c98c0cbbf
[Pallas TPU] Improve lowerings for boolean comparison operations
...
The error when negating a boolean value (https://github.com/jax-ml/jax/issues/24243) has been fixed, so we can lower the boolean comparison operations using boolean algebra instead of the previous workaround.
In addition, the original tests used `allclose` on boolean arrays, which is wrong. I have changed them to use `assertArraysEqual`.
PiperOrigin-RevId: 704294742
2024-12-09 08:23:51 -08:00
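As an illustration of what "lowering boolean comparisons using boolean algebra" can mean, the sketch below expresses each comparison on two booleans with only NOT, AND, OR and XOR, treating False as 0 and True as 1. It is a plain-Python sketch under that assumption, not the actual Pallas TPU lowering; the function names are hypothetical.

```python
import operator
from itertools import product


# Each comparison on booleans, written with only not/and/or/xor.
# False is treated as 0 and True as 1, as in Python and NumPy.
def bool_eq(a: bool, b: bool) -> bool:
    return not (a ^ b)


def bool_ne(a: bool, b: bool) -> bool:
    return a ^ b


def bool_lt(a: bool, b: bool) -> bool:
    return (not a) and b


def bool_le(a: bool, b: bool) -> bool:
    return (not a) or b


def bool_gt(a: bool, b: bool) -> bool:
    return a and (not b)


def bool_ge(a: bool, b: bool) -> bool:
    return a or (not b)


# Pair each sketch with the reference comparison operator.
PAIRS = [
    (bool_eq, operator.eq),
    (bool_ne, operator.ne),
    (bool_lt, operator.lt),
    (bool_le, operator.le),
    (bool_gt, operator.gt),
    (bool_ge, operator.ge),
]


def all_identities_hold() -> bool:
    """Exhaustively check the four input combinations for every comparison."""
    return all(
        sketch(a, b) == reference(a, b)
        for sketch, reference in PAIRS
        for a, b in product([False, True], repeat=2)
    )
```

Because booleans have only four input combinations per binary operation, an exact equality check is both cheap and exhaustive, which is also why a tolerance-based comparison such as `allclose` is the wrong tool for boolean results.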
Chris Jones
3ec55c7723
[pallas:triton] Add support for `DotAlgorithmPreset` `precision` arguments to `dot`.
...
PiperOrigin-RevId: 704208558
2024-12-09 02:52:47 -08:00
Justin Fu
641a1d53ce
[Pallas] Add support for run_state to cost estimator.
...
PiperOrigin-RevId: 703543961
2024-12-06 10:36:02 -08:00
Adam Paszke
eda7506d6b
[Pallas MGPU] Disable XLA:GPU autotuning in attention tests
...
We don't care about performance of the reference impl, we only use it for
correctness testing. More importantly, it works around a deadlock at compile
time that sometimes happens when testing large batch sizes.
PiperOrigin-RevId: 703521029
2024-12-06 09:19:08 -08:00
jax authors
84f3f99217
[pallas] fix jumble test flakiness
...
* Enable interpret mode in tests.
* Ensure that the kernel is run multiple times in the cases where we've seen data corruption.
* Use a masked comparison. The prior comparison was reading garbage data, because it relied on how uninitialized memory happened to behave in the past.
* This was hidden by a cache: the interpret test, which always has 0.0 for uninitialized memory, was being hit first, whereas TPU does not have the same behavior.
PiperOrigin-RevId: 703272002
2024-12-05 15:31:23 -08:00
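The masked comparison mentioned in the jumble test fix above can be pictured as follows. This is a minimal plain-Python sketch, not the test code from the commit: `masked_equal` is a hypothetical helper, and the mask marks which elements were actually written by the kernel.

```python
from typing import Sequence


def masked_equal(
    actual: Sequence[float],
    expected: Sequence[float],
    mask: Sequence[bool],
) -> bool:
    """Compare only the positions where mask is True.

    Positions with mask False stand for uninitialized memory, whose
    contents are unspecified and must not influence the result.
    """
    if not (len(actual) == len(expected) == len(mask)):
        raise ValueError("actual, expected and mask must have the same length")
    return all(a == e for a, e, m in zip(actual, expected, mask) if m)


# The last element was never written, so it may hold arbitrary garbage.
actual = [1.0, 2.0, 123.0]
expected = [1.0, 2.0, 0.0]
valid = [True, True, False]

unmasked_result = actual == expected          # depends on the garbage value
masked_result = masked_equal(actual, expected, valid)
```

An unmasked comparison only passes when the uninitialized region happens to contain the value the reference uses, which is the kind of backend-dependent behavior the commit message describes.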
George Necula
3f5f3e1c47
[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
...
This is because the underlying Triton IR does not guarantee compatibility.
PiperOrigin-RevId: 703127711
2024-12-05 08:42:41 -08:00
Adam Paszke
d5ead570bb
[Mosaic TPU] Add support for modeling loads/stores and fix minor issues in model extraction
...
PiperOrigin-RevId: 703102072
2024-12-05 07:06:19 -08:00
jax authors
db97d7aa3d
Merge pull request #25199 from Rifur13:save_residuals
...
PiperOrigin-RevId: 702824842
2024-12-04 12:45:28 -08:00
Sergei Lebedev
12b45b3235
[pallas:mosaic_gpu] `emit_pipeline` no longer ignores transforms
...
PiperOrigin-RevId: 702726201
2024-12-04 07:59:42 -08:00
Peter Hawkins
2ac2692457
Disable backwards compatibility test for Triton IR.
...
Triton doesn't promise backwards compatibility of its IR, so the test is misguided: it is testing a property that isn't true. If we wanted to promise backwards compatibility, we would need to use a versioned IR across the boundary.
PiperOrigin-RevId: 702725103
2024-12-04 07:55:40 -08:00
Christos Perivolaropoulos
3895e0372c
[mgpu_pallas] Allow loading scalars or indexing arrays from gmem using splat.
...
PiperOrigin-RevId: 702704429
2024-12-04 06:36:23 -08:00
Ayaka
2dae81a8ed
[Pallas TPU] Enable test for `jnp.logical_not` because it's now supported
...
PiperOrigin-RevId: 702439876
2024-12-03 12:56:32 -08:00
Adam Paszke
0bb68f6ad2
[Pallas:MGPU] Add tests for attention with non-trivial batch size
...
PiperOrigin-RevId: 702280467
2024-12-03 03:58:15 -08:00
Justin Fu
784ebeabc8
[Mosaic GPU] Automatically squash a >3D logical grid into a 3D physical CUDA grid.
...
PiperOrigin-RevId: 702013252
2024-12-02 10:32:29 -08:00
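One way to picture squashing a logical grid with more than three dimensions into a 3D physical grid is to collapse several logical dimensions into a single physical one and recover the logical index with integer division. The sketch below assumes the leading dimensions are the ones collapsed, in row-major order. That choice and the helper names `squash_grid` and `unsquash_index` are assumptions for illustration and may differ from what Mosaic GPU does.

```python
import math
from itertools import product
from typing import Sequence, Tuple


def squash_grid(logical_grid: Sequence[int]) -> Tuple[int, ...]:
    """Collapse all but the last two logical dims into one physical dim."""
    if len(logical_grid) <= 3:
        return tuple(logical_grid)
    *leading, y, z = logical_grid
    return (math.prod(leading), y, z)


def unsquash_index(
    physical_index: Sequence[int], logical_grid: Sequence[int]
) -> Tuple[int, ...]:
    """Recover the logical index from a 3D physical index (row-major)."""
    if len(logical_grid) <= 3:
        return tuple(physical_index)
    leading_dims = logical_grid[: len(logical_grid) - 2]
    flat, *rest = physical_index
    leading_index = []
    # Peel off the fastest-varying leading dim first.
    for dim in reversed(leading_dims):
        flat, remainder = divmod(flat, dim)
        leading_index.append(remainder)
    return tuple(reversed(leading_index)) + tuple(rest)


logical = (2, 3, 4, 5)
physical = squash_grid(logical)

# Every physical index maps to a distinct logical index.
recovered = {
    unsquash_index(idx, logical)
    for idx in product(*(range(d) for d in physical))
}
```

With this scheme the total number of program instances is unchanged, and each kernel instance can compute its logical coordinates from the physical block index alone.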
Adam Paszke
aff7714dc0
[Pallas:MGPU] Fix an overly strict precision requirement in tests
...
They started failing after we allowed LLVM to perform contractions of
adds and muls, but the difference is tiny.
PiperOrigin-RevId: 701961845
2024-12-02 07:34:18 -08:00
Gleb Pobudzey
a4e742d2fe
Save residuals in the decode attention pallas kernel
2024-12-02 15:09:16 +00:00
Christos Perivolaropoulos
c3c21c7462
[mgpu_pallas] Better support for unsigned integers and floats in iota.
...
PiperOrigin-RevId: 701307324
2024-11-29 09:39:29 -08:00
Justin Fu
6e72592be6
[Pallas] Fix float -> int casting on Triton backend.
...
PiperOrigin-RevId: 700761545
2024-11-27 11:32:58 -08:00
Christos Perivolaropoulos
8477580d95
[mgpu pallas] Layout iota operation.
...
PiperOrigin-RevId: 700711177
2024-11-27 08:34:10 -08:00
Christos Perivolaropoulos
f828f2d7d0
[mgpu] Pointwise min
...
PiperOrigin-RevId: 700175724
2024-11-25 19:13:51 -08:00
Christos Perivolaropoulos
c5dc980db8
[mgpu/pallas_mgpu] Pointwise tanh support
...
PiperOrigin-RevId: 700158250
2024-11-25 17:56:11 -08:00
Christos Perivolaropoulos
ef7df1ae7c
[pallas_mgpu] Allow trees (eg tuples) to be returned from cond_p expressions.
...
PiperOrigin-RevId: 700136799
2024-11-25 16:36:43 -08:00
Justin Fu
73fa0f48cb
[Pallas] Deprecate dictionary compiler_params in favor of dataclass.
...
PiperOrigin-RevId: 699057658
2024-11-21 23:34:32 -08:00
Sergei Lebedev
1efef6bf6b
[pallas:mosaic_gpu] `emit_pipeline` now correctly supports `BlockSpec`s in GMEM
...
This is necessary to replace the pipelining logic in the lowering with
`emit_pipeline`.
PiperOrigin-RevId: 698858380
2024-11-21 11:38:43 -08:00
Peter Buchlovsky
2178ed2fa4
[pallas] Add more test cases for Triton bitcast_convert_type lowering rule.
...
PiperOrigin-RevId: 698818103
2024-11-21 09:52:04 -08:00