[Mosaic GPU] Ensure that the dialect module can be loaded successfully.

This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.

PiperOrigin-RevId: 693241875
This commit is contained in:
Benjamin Chetioui 2024-11-05 00:46:40 -08:00 committed by jax authors
parent a913fbf2fd
commit 63e59c5fd7
4 changed files with 15 additions and 3 deletions

View File

@ -121,7 +121,7 @@ import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
try:
import jaxlib.mosaic.python.gpu as mosaic_gpu_dialect # pytype: disable=import-error
import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error
except ImportError:
# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36.
# Jaxlib doesn't contain Mosaic GPU dialect bindings.

View File

@ -20,7 +20,7 @@ load("@rules_python//python:defs.bzl", "py_library")
py_library(
name = "gpu_dialect",
srcs = [
"gpu.py",
"mosaic_gpu.py",
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py",
],
visibility = ["//visibility:public"],

View File

@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python bindings for the MLIR Mosaic GPU dialect."""
"""Python bindings for the MLIR Mosaic GPU dialect.
Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect
name. Otherwise, MLIR is unable to find the module during dialect search.
"""
# ruff: noqa: F401
# ruff: noqa: F403

View File

@ -21,6 +21,9 @@ from jax._src.lib.mlir import ir
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
_cext = mgpu._cext if mgpu is not None else None
config.parse_flags_with_absl()
@ -40,6 +43,9 @@ class DialectTest(parameterized.TestCase):
self.enter_context(ir.Location.unknown())
self.module = ir.Module.create()
def test_dialect_module_is_loaded(self):
self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu"))
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
@ -62,6 +68,8 @@ class DialectTest(parameterized.TestCase):
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=1)
self.assertTrue(self.module.operation.verify())
self.assertIsInstance(self.module.body.operations[0],
mgpu.InitializeBarrierOp)
if __name__ == "__main__":