mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[Mosaic GPU] Ensure that the dialect module can be loaded successfully.
This requires that the file providing the bindings has the same name as the dialect it defines, since dialect search looks for a module path of the form `<prefix>.<dialect namespace>`. PiperOrigin-RevId: 693241875
This commit is contained in:
parent
a913fbf2fd
commit
63e59c5fd7
@ -121,7 +121,7 @@ import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
|
||||
|
||||
try:
|
||||
import jaxlib.mosaic.python.gpu as mosaic_gpu_dialect # pytype: disable=import-error
|
||||
import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error
|
||||
except ImportError:
|
||||
# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36.
|
||||
# Jaxlib doesn't contain Mosaic GPU dialect bindings.
|
||||
|
@ -20,7 +20,7 @@ load("@rules_python//python:defs.bzl", "py_library")
|
||||
py_library(
|
||||
name = "gpu_dialect",
|
||||
srcs = [
|
||||
"gpu.py",
|
||||
"mosaic_gpu.py",
|
||||
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -12,7 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Python bindings for the MLIR Mosaic GPU dialect."""
|
||||
"""Python bindings for the MLIR Mosaic GPU dialect.
|
||||
|
||||
Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect
|
||||
name. Otherwise, MLIR is unable to find the module during dialect search.
|
||||
"""
|
||||
|
||||
# ruff: noqa: F401
|
||||
# ruff: noqa: F403
|
@ -21,6 +21,9 @@ from jax._src.lib.mlir import ir
|
||||
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
|
||||
|
||||
|
||||
_cext = mgpu._cext if mgpu is not None else None
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@ -40,6 +43,9 @@ class DialectTest(parameterized.TestCase):
|
||||
self.enter_context(ir.Location.unknown())
|
||||
self.module = ir.Module.create()
|
||||
|
||||
def test_dialect_module_is_loaded(self):
|
||||
self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu"))
|
||||
|
||||
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
@ -62,6 +68,8 @@ class DialectTest(parameterized.TestCase):
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=1)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
self.assertIsInstance(self.module.body.operations[0],
|
||||
mgpu.InitializeBarrierOp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user