[pallas] Added API docs for Triton and Mosaic GPU backends

I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
This commit is contained in:
Sergei Lebedev 2024-10-08 11:13:04 +01:00
parent 351187d9da
commit 46e65b5982
7 changed files with 107 additions and 8 deletions

View File

@ -0,0 +1,41 @@
``jax.experimental.pallas.mosaic_gpu`` module
=============================================
.. automodule:: jax.experimental.pallas.mosaic_gpu
Classes
-------
.. autosummary::
:toctree: _autosummary
Barrier
GPUBlockSpec
GPUCompilerParams
GPUMemorySpace
TilingTransform
TransposeTransform
WGMMAAccumulatorRef
Functions
---------
.. autosummary::
:toctree: _autosummary
copy_gmem_to_smem
copy_smem_to_gmem
wait_barrier
wait_smem_to_gmem
wgmma
wgmma_wait
Aliases
-------
.. autosummary::
:toctree: _autosummary
ACC
GMEM
SMEM

View File

@ -3,6 +3,16 @@
.. automodule:: jax.experimental.pallas
Backends
--------
.. toctree::
:maxdepth: 1
jax.experimental.pallas.mosaic_gpu
jax.experimental.pallas.triton
jax.experimental.pallas.tpu
Classes
-------
@ -36,7 +46,11 @@ Functions
atomic_min
atomic_or
atomic_xchg
atomic_xor
broadcast_to
debug_print
dot
max_contiguous
multiple_of
run_scoped
when

View File

@ -0,0 +1,16 @@
``jax.experimental.pallas.tpu`` module
======================================
.. automodule:: jax.experimental.pallas.tpu
Classes
-------
.. autosummary::
:toctree: _autosummary
Functions
---------
.. autosummary::
:toctree: _autosummary

View File

@ -0,0 +1,22 @@
``jax.experimental.pallas.triton`` module
=========================================
.. automodule:: jax.experimental.pallas.triton
Classes
-------
.. autosummary::
:toctree: _autosummary
TritonCompilerParams
Functions
---------
.. autosummary::
:toctree: _autosummary
approx_tanh
debug_barrier
elementwise_inline_asm

View File

@ -16,17 +16,17 @@ Experimental Modules
jax.experimental.array_api
jax.experimental.checkify
jax.experimental.pjit
jax.experimental.sparse
jax.experimental.jet
jax.experimental.custom_partitioning
jax.experimental.multihost_utils
jax.experimental.compilation_cache
jax.experimental.custom_partitioning
jax.experimental.jet
jax.experimental.key_reuse
jax.experimental.mesh_utils
jax.experimental.multihost_utils
jax.experimental.pallas
jax.experimental.pjit
jax.experimental.serialize_executable
jax.experimental.shard_map
jax.experimental.pallas
jax.experimental.sparse
Experimental APIs
-----------------

View File

@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams):
class GPUMemorySpace(enum.Enum):
#: Global memory.
GMEM = "gmem"
#: Shared memory.
SMEM = "smem"
#: Registers.
REGS = "regs"
def __str__(self) -> str:

View File

@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
@ -31,5 +32,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
GMEM = GPUMemorySpace.GMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
SMEM = GPUMemorySpace.SMEM