mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working on Pallas TPU can fill them in later.
This commit is contained in:
parent
351187d9da
commit
46e65b5982
41
docs/jax.experimental.pallas.mosaic_gpu.rst
Normal file
41
docs/jax.experimental.pallas.mosaic_gpu.rst
Normal file
@ -0,0 +1,41 @@
|
||||
``jax.experimental.pallas.mosaic_gpu`` module
|
||||
=============================================
|
||||
|
||||
.. automodule:: jax.experimental.pallas.mosaic_gpu
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Barrier
|
||||
GPUBlockSpec
|
||||
GPUCompilerParams
|
||||
GPUMemorySpace
|
||||
TilingTransform
|
||||
TransposeTransform
|
||||
WGMMAAccumulatorRef
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
copy_gmem_to_smem
|
||||
copy_smem_to_gmem
|
||||
wait_barrier
|
||||
wait_smem_to_gmem
|
||||
wgmma
|
||||
wgmma_wait
|
||||
|
||||
Aliases
|
||||
-------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ACC
|
||||
GMEM
|
||||
SMEM
|
@ -3,6 +3,16 @@
|
||||
|
||||
.. automodule:: jax.experimental.pallas
|
||||
|
||||
Backends
|
||||
--------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.pallas.mosaic_gpu
|
||||
jax.experimental.pallas.triton
|
||||
jax.experimental.pallas.tpu
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
||||
@ -36,7 +46,11 @@ Functions
|
||||
atomic_min
|
||||
atomic_or
|
||||
atomic_xchg
|
||||
|
||||
atomic_xor
|
||||
broadcast_to
|
||||
debug_print
|
||||
|
||||
dot
|
||||
max_contiguous
|
||||
multiple_of
|
||||
run_scoped
|
||||
when
|
||||
|
16
docs/jax.experimental.pallas.tpu.rst
Normal file
16
docs/jax.experimental.pallas.tpu.rst
Normal file
@ -0,0 +1,16 @@
|
||||
``jax.experimental.pallas.tpu`` module
|
||||
======================================
|
||||
|
||||
.. automodule:: jax.experimental.pallas.tpu
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
22
docs/jax.experimental.pallas.triton.rst
Normal file
22
docs/jax.experimental.pallas.triton.rst
Normal file
@ -0,0 +1,22 @@
|
||||
``jax.experimental.pallas.triton`` module
|
||||
=========================================
|
||||
|
||||
.. automodule:: jax.experimental.pallas.triton
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
TritonCompilerParams
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
approx_tanh
|
||||
debug_barrier
|
||||
elementwise_inline_asm
|
@ -16,17 +16,17 @@ Experimental Modules
|
||||
|
||||
jax.experimental.array_api
|
||||
jax.experimental.checkify
|
||||
jax.experimental.pjit
|
||||
jax.experimental.sparse
|
||||
jax.experimental.jet
|
||||
jax.experimental.custom_partitioning
|
||||
jax.experimental.multihost_utils
|
||||
jax.experimental.compilation_cache
|
||||
jax.experimental.custom_partitioning
|
||||
jax.experimental.jet
|
||||
jax.experimental.key_reuse
|
||||
jax.experimental.mesh_utils
|
||||
jax.experimental.multihost_utils
|
||||
jax.experimental.pallas
|
||||
jax.experimental.pjit
|
||||
jax.experimental.serialize_executable
|
||||
jax.experimental.shard_map
|
||||
jax.experimental.pallas
|
||||
jax.experimental.sparse
|
||||
|
||||
Experimental APIs
|
||||
-----------------
|
||||
|
@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams):
|
||||
|
||||
|
||||
class GPUMemorySpace(enum.Enum):
|
||||
#: Global memory.
|
||||
GMEM = "gmem"
|
||||
#: Shared memory.
|
||||
SMEM = "smem"
|
||||
#: Registers.
|
||||
REGS = "regs"
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
||||
@ -31,5 +32,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
||||
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
|
||||
GMEM = GPUMemorySpace.GMEM
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
|
||||
SMEM = GPUMemorySpace.SMEM
|
||||
|
Loading…
x
Reference in New Issue
Block a user