mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix jaxlib build by not exposing nvcc to pybind11. (#1819)
This commit is contained in:
parent
0c0137d787
commit
7a154f71bc
26
jaxlib/BUILD
26
jaxlib/BUILD
@ -22,6 +22,20 @@ licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "kernel_pybind11_helpers",
|
||||
hdrs = ["kernel_pybind11_helpers.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":kernel_helpers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernel_helpers",
|
||||
hdrs = ["kernel_helpers.h"],
|
||||
@ -32,7 +46,6 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
@ -106,7 +119,7 @@ pybind_extension(
|
||||
module_name = "cublas_kernels",
|
||||
deps = [
|
||||
":gpu_kernel_helpers",
|
||||
":kernel_helpers",
|
||||
":kernel_pybind11_helpers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -134,7 +147,7 @@ pybind_extension(
|
||||
module_name = "cusolver_kernels",
|
||||
deps = [
|
||||
":gpu_kernel_helpers",
|
||||
":kernel_helpers",
|
||||
":kernel_pybind11_helpers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -154,15 +167,10 @@ cuda_library(
|
||||
name = "cuda_prng_kernels_lib",
|
||||
srcs = ["cuda_prng_kernels.cu.cc"],
|
||||
hdrs = ["cuda_prng_kernels.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
deps = [
|
||||
":gpu_kernel_helpers",
|
||||
":kernel_helpers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
@ -177,7 +185,7 @@ pybind_extension(
|
||||
module_name = "cuda_prng_kernels",
|
||||
deps = [
|
||||
":cuda_prng_kernels_lib",
|
||||
":kernel_helpers",
|
||||
":kernel_pybind11_helpers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cudart",
|
||||
"@pybind11",
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "jaxlib/cuda_prng_kernels.h"
|
||||
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
@ -29,7 +29,10 @@ pybind11::dict Registrations() {
|
||||
|
||||
PYBIND11_MODULE(cuda_prng_kernels, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("cuda_threefry2x32_descriptor", &BuildCudaThreeFry2x32Descriptor);
|
||||
m.def("cuda_threefry2x32_descriptor", [](std::int64_t n) {
|
||||
std::string result = BuildCudaThreeFry2x32Descriptor(n);
|
||||
return pybind11::bytes(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
#include "jaxlib/cuda_prng_kernels.h"
|
||||
@ -101,8 +102,8 @@ struct ThreeFry2x32Descriptor {
|
||||
std::int64_t n;
|
||||
};
|
||||
|
||||
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
|
||||
return PackDescriptor(ThreeFry2x32Descriptor{n});
|
||||
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
|
||||
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
|
||||
}
|
||||
|
||||
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
||||
#define JAXLIB_PRNG_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n);
|
||||
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n);
|
||||
|
||||
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len);
|
||||
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
@ -18,22 +18,20 @@ limitations under the License.
|
||||
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
// Descriptor objects are opaque host-side objects used to pass data from JAX
|
||||
// to the custom kernel launched by XLA. Currently simply treat host-side
|
||||
// structures as byte-strings; this is not portable across architectures. If
|
||||
// portability is needed, we could switch to using a representation such as
|
||||
// protocol buffers or flatbuffers.
|
||||
// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out
|
||||
// the functionality that doesn't require pybind11 for building CUDA libraries,
|
||||
// since older versions nvcc don't seem to be able to compile pybind11.
|
||||
|
||||
// Packs a descriptor object into a pybind11::bytes structure.
|
||||
// Packs a descriptor object into a byte string.
|
||||
template <typename T>
|
||||
pybind11::bytes PackDescriptor(const T& descriptor) {
|
||||
return pybind11::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
|
||||
std::string PackDescriptorAsString(const T& descriptor) {
|
||||
return std::string(absl::bit_cast<const char*>(&descriptor), sizeof(T));
|
||||
}
|
||||
|
||||
// Unpacks a descriptor object from a byte string.
|
||||
@ -45,11 +43,6 @@ const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
|
||||
return absl::bit_cast<const T*>(opaque);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
pybind11::capsule EncapsulateFunction(T* fn) {
|
||||
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_KERNEL_HELPERS_H_
|
||||
|
44
jaxlib/kernel_pybind11_helpers.h
Normal file
44
jaxlib/kernel_pybind11_helpers.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2019 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_
|
||||
#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
// Descriptor objects are opaque host-side objects used to pass data from JAX
|
||||
// to the custom kernel launched by XLA. Currently simply treat host-side
|
||||
// structures as byte-strings; this is not portable across architectures. If
|
||||
// portability is needed, we could switch to using a representation such as
|
||||
// protocol buffers or flatbuffers.
|
||||
|
||||
// Packs a descriptor object into a pybind11::bytes structure.
|
||||
// UnpackDescriptor() is available in kernel_helpers.h.
|
||||
template <typename T>
|
||||
pybind11::bytes PackDescriptor(const T& descriptor) {
|
||||
return pybind11::bytes(PackDescriptorAsString(descriptor));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
pybind11::capsule EncapsulateFunction(T* fn) {
|
||||
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_KERNEL_PYBIND11_HELPERS_H_
|
Loading…
x
Reference in New Issue
Block a user