Merge pull request #19377 from superbobry:main

PiperOrigin-RevId: 598866324
This commit is contained in:
jax authors 2024-01-16 09:33:29 -08:00
commit abe820c1e8
3 changed files with 24 additions and 0 deletions

View File

@ -69,6 +69,8 @@ py_library_providing_imports_info(
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
"//jaxlib/triton:compat",
],
)

View File

@ -115,6 +115,10 @@ setup(
'mlir/_mlir_libs/*.pyd',
'mlir/_mlir_libs/*.py',
'rocm/*',
'triton/*.py',
'triton/*.pyi',
'triton/*.pyd',
'triton/*.so',
],
'jaxlib.xla_extension': ['*.pyi'],
},

View File

@ -332,6 +332,24 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
],
)
triton_dir = jaxlib_dir / "triton"
copy_runfiles(
dst_dir=triton_dir,
src_files=[
"__main__/jaxlib/triton/__init__.py",
"__main__/jaxlib/triton/compat.py",
"__main__/jaxlib/triton/dialect.py",
f"__main__/jaxlib/triton/_triton_ext.{pyext}",
"__main__/jaxlib/triton/_triton_ext.pyi",
],
)
patch_copy_mlir_import(
"__main__/jaxlib/triton/_triton_enum_gen.py", dst_dir=triton_dir
)
patch_copy_mlir_import(
"__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir
)
tmpdir = None
sources_path = args.sources_path