Fix cusparse kernel build.

The build_wheel.py script was copying the wrong module.

In addition the CUDA stubs from the TF repo were missing a number of cusparse symbols. The updated TF includes the correct stubs.
This commit is contained in:
Peter Hawkins 2021-04-27 20:28:30 -04:00
parent 4e490894e3
commit a0c96b5ca5
2 changed files with 5 additions and 5 deletions

View File

@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "ce033795d4d58ecb92a0188d6881c29d5a126e62cd21dd1882eb902e2b5ac226",
strip_prefix = "tensorflow-559047cb46f6805dfc50cba6d91b5a1e2d8d1b68",
sha256 = "4e07806f8786baa478d38e89974fbcf8bf9e41839f29b05f210fce02427b39d1",
strip_prefix = "tensorflow-b0e85b5b3859d060a42364c79fe664b07299a0e9",
urls = [
"https://github.com/tensorflow/tensorflow/archive/559047cb46f6805dfc50cba6d91b5a1e2d8d1b68.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/b0e85b5b3859d060a42364c79fe664b07299a0e9.tar.gz",
],
)

View File

@ -193,8 +193,8 @@ def prepare_wheel(sources_path):
if r.Rlocation("__main__/jaxlib/rocblas_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocblas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocsolver.py"))
if r.Rlocation("__main__/jaxlib/cusparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.so"))
if r.Rlocation("__main__/jaxlib/cusparse_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))