From 64433435ffd020c0db3ae4ee02a569d933f84ab6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 2 Jan 2025 15:55:03 +0000 Subject: [PATCH] Fix OSS build for the Mosaic GPU dialect --- jaxlib/mlir/_mlir_libs/BUILD.bazel | 6 ++++-- jaxlib/mosaic/dialect/gpu/BUILD | 23 +++++++++++++++++++++++ jaxlib/mosaic/python/mosaic_gpu.py | 3 ++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 817c23a11..4486a4c4f 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -158,9 +158,10 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + ":jaxlib_mlir_capi_shared_library", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], ) @@ -380,6 +381,7 @@ cc_library( name = "jaxlib_mlir_capi_objects", deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", "@llvm-project//mlir:CAPIArithObjects", "@llvm-project//mlir:CAPIGPUObjects", "@llvm-project//mlir:CAPIIRObjects", diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 681ee708e..50ea58104 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -215,3 +215,26 @@ cc_library( "@llvm-project//mlir:CAPIIR", ], ) + +# Header-only target, used when using the C API from a separate shared library. +cc_library( + name = "gpu_dialect_capi_headers", + hdrs = DIALECT_CAPI_HEADERS, + deps = [ + ":mosaic_gpu_inc_gen", + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + +# Alwayslink target, used when exporting the C API from a shared library. +cc_library( + name = "gpu_dialect_capi_objects", + srcs = DIALECT_CAPI_SOURCES, + hdrs = DIALECT_CAPI_HEADERS, + deps = [ + ":mosaic_gpu", + ":mosaic_gpu_inc_gen", + "@llvm-project//mlir:CAPIIRObjects", + ], + alwayslink = True, +) diff --git a/jaxlib/mosaic/python/mosaic_gpu.py b/jaxlib/mosaic/python/mosaic_gpu.py index f99f53cfd..cce2909be 100644 --- a/jaxlib/mosaic/python/mosaic_gpu.py +++ b/jaxlib/mosaic/python/mosaic_gpu.py @@ -33,4 +33,5 @@ except ImportError: from mlir.dialects._ods_common import _cext # type: ignore[import-not-found] -_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python") +# Add the parent module to the search prefix +_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])