Include compute capability 8.0 SASS in jaxlib wheels.

Drop compute capability 6.1 to avoid growing the wheel size.

Also fix an unrelated build error due to a gcc warning in boringssl.
This commit is contained in:
Peter Hawkins 2021-12-14 14:27:19 -05:00
parent 0404dbdd29
commit 66823d1392
4 changed files with 11 additions and 5 deletions

View File

@ -91,6 +91,7 @@ build:linux --config=posix
# Workaround for gcc 10+ warnings related to upb.
# See https://github.com/tensorflow/tensorflow/issues/39467
build:linux --copt=-Wno-stringop-truncation
build:linux --copt=-Wno-array-parameter
build:macos --config=posix

View File

@ -22,6 +22,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
using JAX custom AD APIs ({jax-issue}`#7839`).
## jaxlib 0.1.76 (Unreleased)
* New features
* Includes precompiled SASS for NVidia compute capability 8.0 GPUS
(e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not
to increase the number of compute capabilities: GPUs with compute capability
6.1 can use the 6.0 SASS.
## jaxlib 0.1.75 (Dec 8, 2021)
* New features:

View File

@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "e6e68270b7fb0dda26656f4985b3d118b5a9d3fa60433ced42ce2fc5676716d5",
strip_prefix = "tensorflow-d9b47d5722fc08c5d06afba9f63177de266801f5",
sha256 = "5f6bb29818543ff6722e51147b8386ba365e152b508dc7b9ad920df0b7125101",
strip_prefix = "tensorflow-dd57f5328f37a81197b0dadd052e05c9d9461b16",
urls = [
"https://github.com/tensorflow/tensorflow/archive/d9b47d5722fc08c5d06afba9f63177de266801f5.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/dd57f5328f37a81197b0dadd052e05c9d9461b16.tar.gz",
],
)

View File

@ -352,7 +352,7 @@ def main():
parser,
"enable_nccl",
default=True,
help_str="Should we build with NCCL enabled? Has non effect for non-CUDA "
help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
"builds.")
add_boolean_argument(
parser,
@ -377,7 +377,7 @@ def main():
help="CUDNN version, e.g., 8")
parser.add_argument(
"--cuda_compute_capabilities",
default="3.5,5.2,6.0,6.1,7.0",
default="3.5,5.2,6.0,7.0,8.0",
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",