diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 4db2b9fb4..f716bebc9 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -22,6 +22,20 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +cc_library( + name = "kernel_pybind11_helpers", + hdrs = ["kernel_pybind11_helpers.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":kernel_helpers", + "@pybind11", + ], +) + cc_library( name = "kernel_helpers", hdrs = ["kernel_helpers.h"], @@ -32,7 +46,6 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", - "@pybind11", ], ) @@ -106,7 +119,7 @@ pybind_extension( module_name = "cublas_kernels", deps = [ ":gpu_kernel_helpers", - ":kernel_helpers", + ":kernel_pybind11_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -134,7 +147,7 @@ pybind_extension( module_name = "cusolver_kernels", deps = [ ":gpu_kernel_helpers", - ":kernel_helpers", + ":kernel_pybind11_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -154,15 +167,10 @@ cuda_library( name = "cuda_prng_kernels_lib", srcs = ["cuda_prng_kernels.cu.cc"], hdrs = ["cuda_prng_kernels.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], deps = [ ":gpu_kernel_helpers", ":kernel_helpers", "@local_config_cuda//cuda:cuda_headers", - "@pybind11", ], ) @@ -177,7 +185,7 @@ pybind_extension( module_name = "cuda_prng_kernels", deps = [ ":cuda_prng_kernels_lib", - ":kernel_helpers", + ":kernel_pybind11_helpers", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudart", "@pybind11", diff --git a/jaxlib/cublas.cc b/jaxlib/cublas.cc index f96f04987..48db7b8b6 100644 --- a/jaxlib/cublas.cc +++ b/jaxlib/cublas.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" #include "jaxlib/gpu_kernel_helpers.h" -#include "jaxlib/kernel_helpers.h" +#include "jaxlib/kernel_pybind11_helpers.h" namespace jax { namespace { diff --git a/jaxlib/cuda_prng_kernels.cc b/jaxlib/cuda_prng_kernels.cc index 3f9eb9c97..b1c57bdb1 100644 --- a/jaxlib/cuda_prng_kernels.cc +++ b/jaxlib/cuda_prng_kernels.cc @@ -15,7 +15,7 @@ limitations under the License. #include "jaxlib/cuda_prng_kernels.h" -#include "jaxlib/kernel_helpers.h" +#include "jaxlib/kernel_pybind11_helpers.h" #include "include/pybind11/pybind11.h" namespace jax { @@ -29,7 +29,10 @@ pybind11::dict Registrations() { PYBIND11_MODULE(cuda_prng_kernels, m) { m.def("registrations", &Registrations); - m.def("cuda_threefry2x32_descriptor", &BuildCudaThreeFry2x32Descriptor); + m.def("cuda_threefry2x32_descriptor", [](std::int64_t n) { + std::string result = BuildCudaThreeFry2x32Descriptor(n); + return pybind11::bytes(result); + }); } } // namespace diff --git a/jaxlib/cuda_prng_kernels.cu.cc b/jaxlib/cuda_prng_kernels.cu.cc index f318c584e..eeedc7d39 100644 --- a/jaxlib/cuda_prng_kernels.cu.cc +++ b/jaxlib/cuda_prng_kernels.cu.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "jaxlib/cuda_prng_kernels.h" @@ -101,8 +102,8 @@ struct ThreeFry2x32Descriptor { std::int64_t n; }; -pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n) { - return PackDescriptor(ThreeFry2x32Descriptor{n}); +std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) { + return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); } void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, diff --git a/jaxlib/cuda_prng_kernels.h b/jaxlib/cuda_prng_kernels.h index 406eeed1d..6512bee59 100644 --- a/jaxlib/cuda_prng_kernels.h +++ b/jaxlib/cuda_prng_kernels.h @@ -17,13 +17,13 @@ limitations under the License. #define JAXLIB_PRNG_KERNELS_H_ #include +#include #include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "include/pybind11/pybind11.h" namespace jax { -pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n); +std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n); void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len); diff --git a/jaxlib/cusolver.cc b/jaxlib/cusolver.cc index 161cadcbc..8d1c49f26 100644 --- a/jaxlib/cusolver.cc +++ b/jaxlib/cusolver.cc @@ -30,7 +30,7 @@ limitations under the License. #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" #include "jaxlib/gpu_kernel_helpers.h" -#include "jaxlib/kernel_helpers.h" +#include "jaxlib/kernel_pybind11_helpers.h" namespace jax { namespace { diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index ce818b4c0..e2c7ba19c 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -18,22 +18,20 @@ limitations under the License. #include #include +#include #include "absl/base/casts.h" -#include "include/pybind11/pybind11.h" namespace jax { -// Descriptor objects are opaque host-side objects used to pass data from JAX -// to the custom kernel launched by XLA. 
Currently simply treat host-side -// structures as byte-strings; this is not portable across architectures. If -// portability is needed, we could switch to using a representation such as -// protocol buffers or flatbuffers. +// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out +// the functionality that doesn't require pybind11 for building CUDA libraries, +// since older versions of nvcc don't seem to be able to compile pybind11. -// Packs a descriptor object into a pybind11::bytes structure. +// Packs a descriptor object into a byte string. template -pybind11::bytes PackDescriptor(const T& descriptor) { - return pybind11::bytes(absl::bit_cast(&descriptor), sizeof(T)); +std::string PackDescriptorAsString(const T& descriptor) { + return std::string(absl::bit_cast(&descriptor), sizeof(T)); } // Unpacks a descriptor object from a byte string. @@ -45,11 +43,6 @@ const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { return absl::bit_cast(opaque); } -template -pybind11::capsule EncapsulateFunction(T* fn) { - return pybind11::capsule(absl::bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); -} - } // namespace jax #endif // JAXLIB_KERNEL_HELPERS_H_ diff --git a/jaxlib/kernel_pybind11_helpers.h b/jaxlib/kernel_pybind11_helpers.h new file mode 100644 index 000000000..8319610d6 --- /dev/null +++ b/jaxlib/kernel_pybind11_helpers.h @@ -0,0 +1,44 @@ +/* Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_ +#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_ + +#include "include/pybind11/pybind11.h" +#include "jaxlib/kernel_helpers.h" + +namespace jax { + +// Descriptor objects are opaque host-side objects used to pass data from JAX +// to the custom kernel launched by XLA. Currently simply treat host-side +// structures as byte-strings; this is not portable across architectures. If +// portability is needed, we could switch to using a representation such as +// protocol buffers or flatbuffers. + +// Packs a descriptor object into a pybind11::bytes structure. +// UnpackDescriptor() is available in kernel_helpers.h. +template +pybind11::bytes PackDescriptor(const T& descriptor) { + return pybind11::bytes(PackDescriptorAsString(descriptor)); +} + +template +pybind11::capsule EncapsulateFunction(T* fn) { + return pybind11::capsule(absl::bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +} // namespace jax + +#endif // JAXLIB_KERNEL_PYBIND11_HELPERS_H_