mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

``jax.experimental.pallas.mosaic_gpu`` module
=============================================

.. automodule:: jax.experimental.pallas.mosaic_gpu

Classes
-------

.. autosummary::
  :toctree: _autosummary

  Barrier
  GPUBlockSpec
  GPUCompilerParams
  GPUMemorySpace
  Layout
  SwizzleTransform
  TilingTransform
  TransposeTransform
  WGMMAAccumulatorRef

Functions
---------

.. autosummary::
  :toctree: _autosummary

  barrier_arrive
  barrier_wait
  commit_smem
  copy_gmem_to_smem
  copy_smem_to_gmem
  emit_pipeline
  layout_cast
  set_max_registers
  wait_smem_to_gmem
  wgmma
  wgmma_wait

Aliases
-------

.. autosummary::
  :toctree: _autosummary

  ACC
  GMEM
  SMEM