6 Commits

Author SHA1 Message Date
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
Sergei Lebedev
089651f35f Added missing BUILD dependencies for Pallas GPU lowering
PiperOrigin-RevId: 617603433
2024-03-20 13:08:58 -07:00
Sergei Lebedev
ad26ba87ca Moved Pallas GPU lowering registartion code into a separate submodule
This makes the layout similar to the one we use in Pallas TPU.

PiperOrigin-RevId: 610411536
2024-02-26 08:10:02 -08:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00