mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 21:36:06 +00:00

Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that can be used independently of Python. The general pattern is that each extension now has two parts: (1) xyz_kernels.{cc,h}, containing the C++ parts; and (2) xyz.cc, containing Python bindings around the C++ parts, including code to build any descriptor objects. There is also a new (minimally supported) module named "gpu_kernels.cc", which registers JAX's GPU kernels with the XLA C++ custom kernel registry. PiperOrigin-RevId: 394460343
94 lines
3.2 KiB
C++
94 lines
3.2 KiB
C++
/* Copyright 2019 Google LLC
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
#include "absl/container/flat_hash_map.h"
|
|
#include "absl/strings/str_format.h"
|
|
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
|
#include "third_party/gpus/cuda/include/cuda.h"
|
|
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
|
#include "jaxlib/cublas_kernels.h"
|
|
#include "jaxlib/kernel_pybind11_helpers.h"
|
|
#include "include/pybind11/numpy.h"
|
|
#include "include/pybind11/pybind11.h"
|
|
#include "include/pybind11/stl.h"
|
|
|
|
namespace jax {
|
|
namespace {
|
|
|
|
namespace py = pybind11;
|
|
|
|
// Converts a NumPy dtype to a Type.
|
|
CublasType DtypeToCublasType(const py::dtype& np_type) {
|
|
static auto* types =
|
|
new absl::flat_hash_map<std::pair<char, int>, CublasType>({
|
|
{{'f', 4}, CublasType::F32},
|
|
{{'f', 8}, CublasType::F64},
|
|
{{'c', 8}, CublasType::C64},
|
|
{{'c', 16}, CublasType::C128},
|
|
});
|
|
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
|
if (it == types->end()) {
|
|
throw std::invalid_argument(
|
|
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
// Returns the descriptor for a TrsmBatched operation.
|
|
std::pair<size_t, py::bytes> BuildTrsmBatchedDescriptor(
|
|
const py::dtype& dtype, int batch, int m, int n, bool left_side, bool lower,
|
|
bool trans_a, bool conj_a, bool unit_diagonal) {
|
|
size_t size = batch * sizeof(void*);
|
|
TrsmBatchedDescriptor desc;
|
|
desc.type = DtypeToCublasType(dtype);
|
|
desc.batch = batch;
|
|
desc.m = m;
|
|
desc.n = n;
|
|
desc.side = left_side ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
|
desc.uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
|
desc.trans = trans_a ? (conj_a ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;
|
|
desc.diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
|
return {size, PackDescriptor(desc)};
|
|
}
|
|
|
|
// Returns the descriptor for a GetrfBatched operation.
|
|
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
|
int b, int n) {
|
|
CublasType type = DtypeToCublasType(dtype);
|
|
size_t size = b * sizeof(void*);
|
|
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
|
|
}
|
|
|
|
// Returns a Python dict from kernel target name to the kernel wrapped by
// EncapsulateFunction. Entries are inserted trsm first, then getrf.
// NOTE(review): presumably the Python caller registers these as XLA
// custom-call targets; confirm against the caller.
py::dict Registrations() {
  py::dict registry;
  registry["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
  registry["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
  return registry;
}
|
|
|
|
// Python extension module `_cublas`: exposes the kernel registration table and
// the descriptor builders defined above.
PYBIND11_MODULE(_cublas, mod) {
  // Name -> wrapped kernel table.
  mod.def("registrations", &Registrations);
  // Descriptor builders; each returns (buffer size in bytes, packed bytes).
  mod.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
  mod.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
}
|
|
|
|
} // namespace
|
|
} // namespace jax
|