mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19377 from superbobry:main
PiperOrigin-RevId: 598866324
This commit is contained in:
commit
abe820c1e8
@ -69,6 +69,8 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"//jaxlib/triton:compat",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -115,6 +115,10 @@ setup(
|
||||
'mlir/_mlir_libs/*.pyd',
|
||||
'mlir/_mlir_libs/*.py',
|
||||
'rocm/*',
|
||||
'triton/*.py',
|
||||
'triton/*.pyi',
|
||||
'triton/*.pyd',
|
||||
'triton/*.so',
|
||||
],
|
||||
'jaxlib.xla_extension': ['*.pyi'],
|
||||
},
|
||||
|
@ -332,6 +332,24 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
||||
],
|
||||
)
|
||||
|
||||
triton_dir = jaxlib_dir / "triton"
|
||||
copy_runfiles(
|
||||
dst_dir=triton_dir,
|
||||
src_files=[
|
||||
"__main__/jaxlib/triton/__init__.py",
|
||||
"__main__/jaxlib/triton/compat.py",
|
||||
"__main__/jaxlib/triton/dialect.py",
|
||||
f"__main__/jaxlib/triton/_triton_ext.{pyext}",
|
||||
"__main__/jaxlib/triton/_triton_ext.pyi",
|
||||
],
|
||||
)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/triton/_triton_enum_gen.py", dst_dir=triton_dir
|
||||
)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir
|
||||
)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
|
Loading…
x
Reference in New Issue
Block a user