mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working on Pallas TPU can fill them in later.
This commit is contained in:
parent
351187d9da
commit
46e65b5982
41
docs/jax.experimental.pallas.mosaic_gpu.rst
Normal file
41
docs/jax.experimental.pallas.mosaic_gpu.rst
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
``jax.experimental.pallas.mosaic_gpu`` module
|
||||||
|
=============================================
|
||||||
|
|
||||||
|
.. automodule:: jax.experimental.pallas.mosaic_gpu
|
||||||
|
|
||||||
|
Classes
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Barrier
|
||||||
|
GPUBlockSpec
|
||||||
|
GPUCompilerParams
|
||||||
|
GPUMemorySpace
|
||||||
|
TilingTransform
|
||||||
|
TransposeTransform
|
||||||
|
WGMMAAccumulatorRef
|
||||||
|
|
||||||
|
Functions
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
copy_gmem_to_smem
|
||||||
|
copy_smem_to_gmem
|
||||||
|
wait_barrier
|
||||||
|
wait_smem_to_gmem
|
||||||
|
wgmma
|
||||||
|
wgmma_wait
|
||||||
|
|
||||||
|
Aliases
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
ACC
|
||||||
|
GMEM
|
||||||
|
SMEM
|
@ -3,6 +3,16 @@
|
|||||||
|
|
||||||
.. automodule:: jax.experimental.pallas
|
.. automodule:: jax.experimental.pallas
|
||||||
|
|
||||||
|
Backends
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
jax.experimental.pallas.mosaic_gpu
|
||||||
|
jax.experimental.pallas.triton
|
||||||
|
jax.experimental.pallas.tpu
|
||||||
|
|
||||||
Classes
|
Classes
|
||||||
-------
|
-------
|
||||||
|
|
||||||
@ -36,7 +46,11 @@ Functions
|
|||||||
atomic_min
|
atomic_min
|
||||||
atomic_or
|
atomic_or
|
||||||
atomic_xchg
|
atomic_xchg
|
||||||
|
atomic_xor
|
||||||
|
broadcast_to
|
||||||
debug_print
|
debug_print
|
||||||
|
dot
|
||||||
|
max_contiguous
|
||||||
|
multiple_of
|
||||||
run_scoped
|
run_scoped
|
||||||
|
when
|
||||||
|
16
docs/jax.experimental.pallas.tpu.rst
Normal file
16
docs/jax.experimental.pallas.tpu.rst
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
``jax.experimental.pallas.tpu`` module
|
||||||
|
======================================
|
||||||
|
|
||||||
|
.. automodule:: jax.experimental.pallas.tpu
|
||||||
|
|
||||||
|
Classes
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Functions
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
22
docs/jax.experimental.pallas.triton.rst
Normal file
22
docs/jax.experimental.pallas.triton.rst
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
``jax.experimental.pallas.triton`` module
|
||||||
|
=========================================
|
||||||
|
|
||||||
|
.. automodule:: jax.experimental.pallas.triton
|
||||||
|
|
||||||
|
Classes
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
TritonCompilerParams
|
||||||
|
|
||||||
|
Functions
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
approx_tanh
|
||||||
|
debug_barrier
|
||||||
|
elementwise_inline_asm
|
@ -16,17 +16,17 @@ Experimental Modules
|
|||||||
|
|
||||||
jax.experimental.array_api
|
jax.experimental.array_api
|
||||||
jax.experimental.checkify
|
jax.experimental.checkify
|
||||||
jax.experimental.pjit
|
|
||||||
jax.experimental.sparse
|
|
||||||
jax.experimental.jet
|
|
||||||
jax.experimental.custom_partitioning
|
|
||||||
jax.experimental.multihost_utils
|
|
||||||
jax.experimental.compilation_cache
|
jax.experimental.compilation_cache
|
||||||
|
jax.experimental.custom_partitioning
|
||||||
|
jax.experimental.jet
|
||||||
jax.experimental.key_reuse
|
jax.experimental.key_reuse
|
||||||
jax.experimental.mesh_utils
|
jax.experimental.mesh_utils
|
||||||
|
jax.experimental.multihost_utils
|
||||||
|
jax.experimental.pallas
|
||||||
|
jax.experimental.pjit
|
||||||
jax.experimental.serialize_executable
|
jax.experimental.serialize_executable
|
||||||
jax.experimental.shard_map
|
jax.experimental.shard_map
|
||||||
jax.experimental.pallas
|
jax.experimental.sparse
|
||||||
|
|
||||||
Experimental APIs
|
Experimental APIs
|
||||||
-----------------
|
-----------------
|
||||||
|
@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams):
|
|||||||
|
|
||||||
|
|
||||||
class GPUMemorySpace(enum.Enum):
|
class GPUMemorySpace(enum.Enum):
|
||||||
|
#: Global memory.
|
||||||
GMEM = "gmem"
|
GMEM = "gmem"
|
||||||
|
#: Shared memory.
|
||||||
SMEM = "smem"
|
SMEM = "smem"
|
||||||
|
#: Registers.
|
||||||
REGS = "regs"
|
REGS = "regs"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
|
|||||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
||||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
||||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
||||||
|
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
|
||||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
||||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
||||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
||||||
@ -31,5 +32,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
|
|||||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
||||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
||||||
|
|
||||||
|
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
|
||||||
GMEM = GPUMemorySpace.GMEM
|
GMEM = GPUMemorySpace.GMEM
|
||||||
|
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
|
||||||
SMEM = GPUMemorySpace.SMEM
|
SMEM = GPUMemorySpace.SMEM
|
||||||
|
Loading…
x
Reference in New Issue
Block a user