Fix jaxlib build by not exposing nvcc to pybind11. (#1819)

This commit is contained in:
Skye Wanderman-Milne 2019-12-05 18:59:29 -08:00 committed by GitHub
parent 0c0137d787
commit 7a154f71bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 80 additions and 31 deletions

View File

@ -22,6 +22,20 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "kernel_pybind11_helpers",
hdrs = ["kernel_pybind11_helpers.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":kernel_helpers",
"@pybind11",
],
)
cc_library(
name = "kernel_helpers",
hdrs = ["kernel_helpers.h"],
@ -32,7 +46,6 @@ cc_library(
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base",
"@pybind11",
],
)
@ -106,7 +119,7 @@ pybind_extension(
module_name = "cublas_kernels",
deps = [
":gpu_kernel_helpers",
":kernel_helpers",
":kernel_pybind11_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
@ -134,7 +147,7 @@ pybind_extension(
module_name = "cusolver_kernels",
deps = [
":gpu_kernel_helpers",
":kernel_helpers",
":kernel_pybind11_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
@ -154,15 +167,10 @@ cuda_library(
name = "cuda_prng_kernels_lib",
srcs = ["cuda_prng_kernels.cu.cc"],
hdrs = ["cuda_prng_kernels.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
deps = [
":gpu_kernel_helpers",
":kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)
@ -177,7 +185,7 @@ pybind_extension(
module_name = "cuda_prng_kernels",
deps = [
":cuda_prng_kernels_lib",
":kernel_helpers",
":kernel_pybind11_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart",
"@pybind11",

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "jaxlib/kernel_pybind11_helpers.h"
namespace jax {
namespace {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "jaxlib/cuda_prng_kernels.h"
#include "jaxlib/kernel_helpers.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
@ -29,7 +29,10 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(cuda_prng_kernels, m) {
m.def("registrations", &Registrations);
m.def("cuda_threefry2x32_descriptor", &BuildCudaThreeFry2x32Descriptor);
m.def("cuda_threefry2x32_descriptor", [](std::int64_t n) {
std::string result = BuildCudaThreeFry2x32Descriptor(n);
return pybind11::bytes(result);
});
}
} // namespace

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstddef>
#include "jaxlib/cuda_prng_kernels.h"
@ -101,8 +102,8 @@ struct ThreeFry2x32Descriptor {
std::int64_t n;
};
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptor(ThreeFry2x32Descriptor{n});
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,

View File

@ -17,13 +17,13 @@ limitations under the License.
#define JAXLIB_PRNG_KERNELS_H_
#include <cstddef>
#include <string>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "include/pybind11/pybind11.h"
namespace jax {
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n);
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n);
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len);

View File

@ -30,7 +30,7 @@ limitations under the License.
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "jaxlib/kernel_pybind11_helpers.h"
namespace jax {
namespace {

View File

@ -18,22 +18,20 @@ limitations under the License.
#include <cstddef>
#include <stdexcept>
#include <string>
#include "absl/base/casts.h"
#include "include/pybind11/pybind11.h"
namespace jax {
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out
// the functionality that doesn't require pybind11 for building CUDA libraries,
// since older versions nvcc don't seem to be able to compile pybind11.
// Packs a descriptor object into a pybind11::bytes structure.
// Packs a descriptor object into a byte string.
template <typename T>
pybind11::bytes PackDescriptor(const T& descriptor) {
return pybind11::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
std::string PackDescriptorAsString(const T& descriptor) {
return std::string(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}
// Unpacks a descriptor object from a byte string.
@ -45,11 +43,6 @@ const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
return absl::bit_cast<const T*>(opaque);
}
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
} // namespace jax
#endif // JAXLIB_KERNEL_HELPERS_H_

View File

@ -0,0 +1,44 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_
#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_
#include "include/pybind11/pybind11.h"
#include "jaxlib/kernel_helpers.h"
namespace jax {
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// Packs a descriptor object into a pybind11::bytes structure.
// UnpackDescriptor() is available in kernel_helpers.h.
template <typename T>
pybind11::bytes PackDescriptor(const T& descriptor) {
return pybind11::bytes(PackDescriptorAsString(descriptor));
}
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
} // namespace jax
#endif // JAXLIB_KERNEL_PYBIND11_HELPERS_H_