2022-09-22 12:26:48 -07:00
|
|
|
/* Copyright 2019 The JAX Authors.
|
2019-12-05 18:59:29 -08:00
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
#ifndef JAXLIB_KERNEL_NANOBIND_HELPERS_H_
|
|
|
|
#define JAXLIB_KERNEL_NANOBIND_HELPERS_H_
|
2019-12-05 18:59:29 -08:00
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
#include <string>
|
2024-07-10 12:08:30 -07:00
|
|
|
#include <type_traits>
|
2023-08-24 16:06:18 -07:00
|
|
|
|
|
|
|
#include "nanobind/nanobind.h"
|
2020-10-23 14:20:06 -07:00
|
|
|
#include "absl/base/casts.h"
|
2019-12-05 18:59:29 -08:00
|
|
|
#include "jaxlib/kernel_helpers.h"
|
2024-07-10 12:08:30 -07:00
|
|
|
#include "xla/ffi/api/c_api.h"
|
2024-03-29 13:05:35 -07:00
|
|
|
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
|
2019-12-05 18:59:29 -08:00
|
|
|
|
|
|
|
namespace jax {
|
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
// Caution: to use this type you must call tsl::ImportNumpy() in your module
|
|
|
|
// initialization function. Otherwise PyArray_DescrCheck will be nullptr.
|
|
|
|
class dtype : public nanobind::object {
|
|
|
|
public:
|
|
|
|
NB_OBJECT_DEFAULT(dtype, object, "dtype", PyArray_DescrCheck); // NOLINT
|
|
|
|
|
2023-09-05 10:14:26 -04:00
|
|
|
int itemsize() const { return nanobind::cast<int>(attr("itemsize")); }
|
2023-08-24 16:06:18 -07:00
|
|
|
|
|
|
|
/// Single-character code for dtype's kind.
|
|
|
|
/// For example, floating point types are 'f' and integral types are 'i'.
|
|
|
|
char kind() const { return nanobind::cast<char>(attr("kind")); }
|
|
|
|
};
|
|
|
|
|
2019-12-05 18:59:29 -08:00
|
|
|
// Descriptor objects are opaque host-side objects used to pass data from JAX
|
|
|
|
// to the custom kernel launched by XLA. Currently we simply treat host-side
|
|
|
|
// structures as byte-strings; this is not portable across architectures. If
|
|
|
|
// portability is needed, we could switch to using a representation such as
|
|
|
|
// protocol buffers or flatbuffers.
|
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
// Packs a descriptor object into a nanobind::bytes structure.
|
2019-12-05 18:59:29 -08:00
|
|
|
// UnpackDescriptor() is available in kernel_helpers.h.
|
|
|
|
template <typename T>
|
2023-08-24 16:06:18 -07:00
|
|
|
nanobind::bytes PackDescriptor(const T& descriptor) {
|
|
|
|
std::string s = PackDescriptorAsString(descriptor);
|
|
|
|
return nanobind::bytes(s.data(), s.size());
|
2019-12-05 18:59:29 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
2023-08-24 16:06:18 -07:00
|
|
|
nanobind::capsule EncapsulateFunction(T* fn) {
|
|
|
|
return nanobind::capsule(absl::bit_cast<void*>(fn),
|
2020-10-23 14:20:06 -07:00
|
|
|
"xla._CUSTOM_CALL_TARGET");
|
2019-12-05 18:59:29 -08:00
|
|
|
}
|
|
|
|
|
2024-07-10 12:08:30 -07:00
|
|
|
template <typename T>
|
|
|
|
nanobind::capsule EncapsulateFfiHandler(T* fn) {
|
|
|
|
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
|
|
|
|
"Encapsulated function must be an XLA FFI handler");
|
|
|
|
return nanobind::capsule(absl::bit_cast<void*>(fn));
|
|
|
|
}
|
|
|
|
|
2019-12-05 18:59:29 -08:00
|
|
|
} // namespace jax
|
|
|
|
|
2023-08-24 16:06:18 -07:00
|
|
|
#endif // JAXLIB_KERNEL_NANOBIND_HELPERS_H_
|