mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10752 from cloudhan:fix-windows-copy
PiperOrigin-RevId: 449585033
This commit is contained in:
commit
8c694fd008
@ -61,21 +61,17 @@ def _is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
def _copy_so(src_file, dst_dir, dst_filename=None):
|
||||
src_filename = os.path.basename(src_file)
|
||||
if not dst_filename:
|
||||
if _is_windows() and src_filename.endswith(".so"):
|
||||
dst_filename = src_filename[:-3] + ".pyd"
|
||||
else:
|
||||
dst_filename = src_filename
|
||||
dst_file = os.path.join(dst_dir, dst_filename)
|
||||
if _is_windows():
|
||||
shutil.copyfile(src_file, dst_file)
|
||||
else:
|
||||
shutil.copy(src_file, dst_file)
|
||||
pyext = "pyd" if _is_windows() else "so"
|
||||
|
||||
|
||||
def _copy_normal(src_file, dst_dir, dst_filename=None):
|
||||
def exists(src_file):
|
||||
return r.Rlocation(src_file) is not None
|
||||
|
||||
|
||||
def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):
|
||||
if from_runfiles:
|
||||
src_file = r.Rlocation(src_file)
|
||||
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
if _is_windows():
|
||||
@ -84,13 +80,6 @@ def _copy_normal(src_file, dst_dir, dst_filename=None):
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def copy_file(src_file, dst_dir, dst_filename=None):
|
||||
if src_file.endswith(".so"):
|
||||
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
|
||||
else:
|
||||
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
|
||||
|
||||
|
||||
_XLA_EXTENSION_STUBS = [
|
||||
"__init__.pyi",
|
||||
"jax_jit.pyi",
|
||||
@ -182,58 +171,44 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
|
||||
|
||||
verify_mac_libraries_dont_reference_chkstack()
|
||||
copy_to_jaxlib(r.Rlocation("__main__/build/LICENSE.txt"),
|
||||
dst_dir=sources_path)
|
||||
copy_file(r.Rlocation("__main__/jaxlib/setup.py"), dst_dir=sources_path)
|
||||
copy_file(r.Rlocation("__main__/jaxlib/setup.cfg"), dst_dir=sources_path)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/init.py"),
|
||||
dst_filename="__init__.py")
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/lapack.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_lapack.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mhlo_helpers.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/gpu_prng.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/gpu_linalg.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/gpu_solver.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/gpu_sparse.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
|
||||
copy_file("__main__/build/LICENSE.txt", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/setup.cfg", dst_dir=sources_path)
|
||||
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/lapack.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/version.py")
|
||||
|
||||
pyext = "pyd" if _is_windows() else "so"
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if r.Rlocation(f"__main__/jaxlib/cuda/_cusolver.{pyext}") is not None:
|
||||
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
|
||||
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
|
||||
os.makedirs(libdevice_dir)
|
||||
copy_file(r.Rlocation("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"),
|
||||
dst_dir=libdevice_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/cuda/_cusolver.{pyext}"),
|
||||
dst_dir=cuda_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/cuda/_cublas.{pyext}"),
|
||||
dst_dir=cuda_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}"),
|
||||
dst_dir=cuda_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}"),
|
||||
dst_dir=cuda_dir)
|
||||
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir)
|
||||
rocm_dir = os.path.join(jaxlib_dir, "rocm")
|
||||
if r.Rlocation(f"__main__/jaxlib/rocm/_hipsolver.{pyext}") is not None:
|
||||
if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"):
|
||||
os.makedirs(rocm_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"),
|
||||
dst_dir=rocm_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/rocm/_hipblas.{pyext}"),
|
||||
dst_dir=rocm_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}"),
|
||||
dst_dir=rocm_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/rocm/_hip_prng.{pyext}"),
|
||||
dst_dir=rocm_dir)
|
||||
if r.Rlocation(f"__main__/jaxlib/cuda/_cusparse.{pyext}") is not None:
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/cuda/_cusparse.{pyext}"),
|
||||
dst_dir=cuda_dir)
|
||||
if r.Rlocation(f"__main__/jaxlib/rocm/_hipsparse.{pyext}") is not None:
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"),
|
||||
dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir)
|
||||
if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir)
|
||||
if exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir)
|
||||
|
||||
|
||||
mlir_dir = os.path.join(jaxlib_dir, "mlir")
|
||||
@ -244,41 +219,40 @@ def prepare_wheel(sources_path):
|
||||
os.makedirs(mlir_dialects_dir)
|
||||
os.makedirs(mlir_libs_dir)
|
||||
os.makedirs(mlir_transforms_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/ir.py"), dst_dir=mlir_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/passmanager.py"), dst_dir=mlir_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/transforms/__init__.py"), dst_dir=mlir_transforms_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_ods_common.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_ext.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/sparse_tensor.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/builtin.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/chlo.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/mhlo.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/func.py"),
|
||||
dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/transforms/__init__.py", dst_dir=mlir_transforms_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)
|
||||
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation(f"org_tensorflow/tensorflow/compiler/xla/python/xla_extension.{pyext}"))
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(f"org_tensorflow/tensorflow/compiler/xla/python/xla_extension.{pyext}")
|
||||
if _is_windows():
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"), dst_dir=mlir_libs_dir)
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
|
||||
elif _is_mac():
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib"), dst_dir=mlir_libs_dir)
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir)
|
||||
else:
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so"), dst_dir=mlir_libs_dir)
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
patch_copy_xla_client_py(jaxlib_dir)
|
||||
|
||||
if not _is_windows():
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
|
||||
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user