[mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib

This commit is contained in:
Sergei Lebedev 2024-12-23 13:46:25 +00:00
parent 3e7f48114c
commit 8987867faa
3 changed files with 14 additions and 0 deletions

View File

@ -28,6 +28,7 @@ package(
py_library(
name = "mosaic",
deps = [
"//jaxlib/mosaic/python:gpu_dialect",
"//jaxlib/mosaic/python:tpu_dialect",
],
)

View File

@ -83,6 +83,7 @@ setup(
'cuda/*',
'cuda/nvvm/libdevice/libdevice*',
'mosaic/*.py',
'mosaic/dialect/gpu/*.py',
'mosaic/gpu/*.so',
'mosaic/python/*.py',
'mosaic/python/*.so',

View File

@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
dst_dir=mosaic_python_dir,
src_files=[
"__main__/jaxlib/mosaic/python/layout_defs.py",
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
"__main__/jaxlib/mosaic/python/tpu.py",
],
)
@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
os.makedirs(mosaic_gpu_dir)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
dst_dir=mosaic_gpu_dir,
)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
dst_dir=mosaic_gpu_dir,
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",