mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib
This commit is contained in:
parent
3e7f48114c
commit
8987867faa
@ -28,6 +28,7 @@ package(
|
||||
py_library(
|
||||
name = "mosaic",
|
||||
deps = [
|
||||
"//jaxlib/mosaic/python:gpu_dialect",
|
||||
"//jaxlib/mosaic/python:tpu_dialect",
|
||||
],
|
||||
)
|
||||
|
@ -83,6 +83,7 @@ setup(
|
||||
'cuda/*',
|
||||
'cuda/nvvm/libdevice/libdevice*',
|
||||
'mosaic/*.py',
|
||||
'mosaic/dialect/gpu/*.py',
|
||||
'mosaic/gpu/*.so',
|
||||
'mosaic/python/*.py',
|
||||
'mosaic/python/*.so',
|
||||
|
@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
dst_dir=mosaic_python_dir,
|
||||
src_files=[
|
||||
"__main__/jaxlib/mosaic/python/layout_defs.py",
|
||||
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
|
||||
"__main__/jaxlib/mosaic/python/tpu.py",
|
||||
],
|
||||
)
|
||||
@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
|
||||
)
|
||||
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
|
||||
os.makedirs(mosaic_gpu_dir)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
|
||||
dst_dir=mosaic_gpu_dir,
|
||||
)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
|
||||
dst_dir=mosaic_gpu_dir,
|
||||
)
|
||||
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "mlir",
|
||||
@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
|
||||
|
Loading…
x
Reference in New Issue
Block a user