Sergei Lebedev
498e81ab10
Pallas now exclusively uses XLA for compiling kernels on GPU
...
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.
PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Sergei Lebedev
f74f4ed48b
Removed unnecessary BUILD dependencies from :ops_test
...
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
Sergei Lebedev
089651f35f
Added missing BUILD dependencies for Pallas GPU lowering
...
PiperOrigin-RevId: 617603433
2024-03-20 13:08:58 -07:00
Sergei Lebedev
ad26ba87ca
Moved Pallas GPU lowering registartion code into a separate submodule
...
This makes the layout similar to the one we use in Pallas TPU.
PiperOrigin-RevId: 610411536
2024-02-26 08:10:02 -08:00
Richard Levasseur
f891cbf64b
Load Python rules from rules_python
...
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Sharad Vikram
d872812a35
[Pallas] Upstream pallas to JAX
...
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00