2022-09-22 12:26:48 -07:00
|
|
|
/* Copyright 2021 The JAX Authors.
|
2022-02-15 17:54:02 +00:00
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
2024-07-10 12:08:30 -07:00
|
|
|
#include "jaxlib/gpu/linalg_kernels.h"
|
2022-02-15 17:54:02 +00:00
|
|
|
|
2024-07-10 12:08:30 -07:00
|
|
|
#include <cstddef>
#include <cstdint>
#include <string>
#include <string_view>
|
2022-05-24 14:20:53 -07:00
|
|
|
|
2024-06-27 09:24:15 -07:00
|
|
|
#include "absl/status/status.h"
|
|
|
|
#include "absl/status/statusor.h"
|
|
|
|
#include "absl/strings/str_format.h"
|
2024-07-10 12:08:30 -07:00
|
|
|
#include "jaxlib/ffi_helpers.h"
|
2022-10-25 07:23:07 -07:00
|
|
|
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
|
|
|
#include "jaxlib/gpu/vendor.h"
|
2024-07-10 12:08:30 -07:00
|
|
|
#include "jaxlib/kernel_helpers.h"
|
2024-05-02 08:11:09 -07:00
|
|
|
#include "xla/ffi/api/ffi.h"
|
2024-07-10 12:08:30 -07:00
|
|
|
#include "xla/service/custom_call_status.h"
|
2022-02-15 17:54:02 +00:00
|
|
|
|
|
|
|
namespace jax {
|
2022-10-25 07:23:07 -07:00
|
|
|
namespace JAX_GPU_NAMESPACE {
|
2022-02-15 17:54:02 +00:00
|
|
|
|
2024-05-02 08:11:09 -07:00
|
|
|
namespace ffi = xla::ffi;
|
|
|
|
|
2024-07-10 12:08:30 -07:00
|
|
|
namespace {
|
|
|
|
// Legacy (opaque-descriptor) implementation of the Cholesky update custom
// call.
//
// Decodes a CholeskyUpdateDescriptor from the `opaque`/`opaque_len` bytes,
// launches the update kernel over `buffers` on `stream`, and then checks the
// runtime's last-error state. Any failure is returned as a non-OK status.
absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers,
                                const char* opaque, std::size_t opaque_len) {
  auto unpacked =
      UnpackDescriptor<CholeskyUpdateDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(unpacked.status());
  const CholeskyUpdateDescriptor& descriptor = **unpacked;

  auto launch_result = LaunchCholeskyUpdateKernel(stream, buffers, descriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(launch_result));

  // Also surface any error recorded by the GPU runtime for the launch.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
  return absl::OkStatus();
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Legacy XLA custom-call entry point for the Cholesky update.
//
// Forwards to CholeskyUpdateImpl. On failure, the status message is reported
// to XLA via XlaCustomCallStatusSetFailure. On success, `status` is not
// touched.
void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result =
      CholeskyUpdateImpl(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  const std::string_view failure_text = result.message();
  XlaCustomCallStatusSetFailure(status, failure_text.data(),
                                failure_text.length());
}
|
|
|
|
|
2024-07-10 12:08:30 -07:00
|
|
|
namespace {
|
2024-08-20 05:45:19 -07:00
|
|
|
// FFI implementation of the Cholesky update.
//
// `matrix_in` is split by SplitBatch2D into (batch, rows, cols) and must be
// square. `vector_in`, `matrix_out` and `vector_out` must have shapes
// {batch, cols}, {batch, rows, cols} and {batch, cols} respectively. All four
// buffers must share one dtype, either float32 or float64.
//
// The inputs are copied into the output buffers unless they are already the
// same allocation. The kernel then updates the outputs in place, launched
// once per batch element.
//
// Returns InvalidArgument on shape or dtype mismatch. Returns an error if
// `cols` does not fit in an int, or if a copy or kernel launch fails.
ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in,
                                 ffi::AnyBuffer vector_in,
                                 ffi::Result<ffi::AnyBuffer> matrix_out,
                                 ffi::Result<ffi::AnyBuffer> vector_out) {
  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(matrix_in.dimensions()));
  if (rows != cols) {
    return ffi::Error::InvalidArgument(
        "The matrix input to Cholesky update must be square.");
  }
  FFI_RETURN_IF_ERROR(CheckShape(vector_in.dimensions(), {batch, cols},
                                 "vector", "cholesky_update"));
  FFI_RETURN_IF_ERROR(CheckShape(matrix_out->dimensions(), {batch, rows, cols},
                                 "matrix_out", "cholesky_update"));
  FFI_RETURN_IF_ERROR(CheckShape(vector_out->dimensions(), {batch, cols},
                                 "vector_out", "cholesky_update"));
  FFI_ASSIGN_OR_RETURN(auto size, MaybeCastNoOverflow<int>(cols));
  auto dtype = matrix_in.element_type();
  if (dtype != ffi::F32 && dtype != ffi::F64) {
    return ffi::Error::InvalidArgument(
        "Invalid input type for Cholesky update; must be float32 or float64.");
  }
  if (vector_in.element_type() != dtype ||
      matrix_out->element_type() != dtype ||
      vector_out->element_type() != dtype) {
    return ffi::Error::InvalidArgument(
        "All input and output types for Cholesky update must match.");
  }
  bool is_single_precision = dtype == ffi::F32;

  // The kernel updates the output buffers in place. Copy the inputs in first
  // unless the output already is the input allocation.
  auto matrix = matrix_out->untyped_data();
  if (matrix_in.untyped_data() != matrix) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
        gpuMemcpyAsync(matrix, matrix_in.untyped_data(), matrix_in.size_bytes(),
                       gpuMemcpyDeviceToDevice, stream)));
  }
  auto vector = vector_out->untyped_data();
  if (vector_in.untyped_data() != vector) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
        gpuMemcpyAsync(vector, vector_in.untyped_data(), vector_in.size_bytes(),
                       gpuMemcpyDeviceToDevice, stream)));
  }

  // Each kernel launch handles a single size x size matrix and length-size
  // vector, so step to the next batch element's slice after every launch.
  // Reusing the base pointers would update only the first element, once per
  // batch entry. Assumes the batch dimensions are the most-major ones (the
  // default dense layout); the copies above already treat the buffers as
  // contiguous. Pointer arithmetic is done on char* because void* arithmetic
  // is not valid C++.
  const std::size_t element_bytes =
      is_single_precision ? sizeof(float) : sizeof(double);
  const std::size_t vector_stride =
      static_cast<std::size_t>(size) * element_bytes;
  const std::size_t matrix_stride =
      static_cast<std::size_t>(size) * vector_stride;
  char* matrix_slice = static_cast<char*>(matrix);
  char* vector_slice = static_cast<char*>(vector);
  for (std::int64_t n = 0; n < batch; ++n) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel(
        stream, matrix_slice, vector_slice, size, is_single_precision)));
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
    matrix_slice += matrix_stride;
    vector_slice += vector_stride;
  }
  return ffi::Error::Success();
}
|
2024-08-20 05:45:19 -07:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Registers the XLA FFI handler symbol CholeskyUpdateFfi, backed by
// CholeskyUpdateFfiImpl. The binding order must mirror the implementation's
// parameter list: the platform stream, the two input buffers, then the two
// result buffers.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    CholeskyUpdateFfi, CholeskyUpdateFfiImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<gpuStream_t>>()  // stream
        .Arg<ffi::AnyBuffer>()                    // matrix_in
        .Arg<ffi::AnyBuffer>()                    // vector_in
        .Ret<ffi::AnyBuffer>()                    // matrix_out
        .Ret<ffi::AnyBuffer>());                  // vector_out
|
2024-08-08 07:35:06 -07:00
|
|
|
|
2024-08-20 05:45:19 -07:00
|
|
|
namespace {
|
2024-08-08 07:35:06 -07:00
|
|
|
// FFI implementation that converts LU pivot indices into permutations on the
// GPU.
//
// SplitBatch1D splits `pivots` into (batch_size, pivot_size) and
// `permutation` into (permutation_batch, permutation_size). The batch sizes
// must agree, and permutation_size must be >= pivot_size. Both are S32
// buffers. The leading Dictionary argument is accepted and ignored.
//
// The pivot-to-permutation convention itself lives in
// LaunchLuPivotsToPermutationKernel (presumably LAPACK getrf-style row swaps;
// confirm there).
//
// Returns InvalidArgument on shape mismatch. Returns an error if the sizes do
// not fit in int32, or if the kernel launch reports an error.
ffi::Error LuPivotsToPermutationImpl(
    gpuStream_t stream, ffi::Dictionary /* unused */,
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation) {
  FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]),
                       SplitBatch1D(pivots.dimensions()));
  FFI_ASSIGN_OR_RETURN((auto [permutation_batch, permutation_size]),
                       SplitBatch1D(permutation->dimensions()));
  if (permutation_batch != batch_size) {
    return ffi::Error::InvalidArgument(
        "pivots and permutation must have the same batch size.");
  }
  if (permutation_size < pivot_size) {
    return ffi::Error::InvalidArgument(
        absl::StrFormat("Output permutation size %d must match or exceed the "
                        "trailing dimension of the input pivots %d.",
                        permutation_size, pivot_size));
  }
  // The output stores int32 indices in [0, permutation_size), so that size
  // must be representable as int32. pivot_size <= permutation_size then fits
  // too. Checking explicitly (as CholeskyUpdateFfiImpl does for its size)
  // avoids silently narrowing the int64 dimensions handed to the kernel.
  FFI_ASSIGN_OR_RETURN(auto permutation_size32,
                       MaybeCastNoOverflow<std::int32_t>(permutation_size));
  FFI_ASSIGN_OR_RETURN(auto pivot_size32,
                       MaybeCastNoOverflow<std::int32_t>(pivot_size));
  LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size32,
                                    permutation_size32, pivots.typed_data(),
                                    permutation->typed_data());
  // The launch helper returns nothing; query the runtime for launch errors.
  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
  return ffi::Error::Success();
}
|
2024-07-10 12:08:30 -07:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Registers the XLA FFI handler symbol LuPivotsToPermutation, backed by
// LuPivotsToPermutationImpl. The binding order mirrors the implementation's
// parameters: stream, (ignored) attribute dictionary, S32 pivots input,
// S32 permutation result.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    LuPivotsToPermutation, LuPivotsToPermutationImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<gpuStream_t>>()  // stream
        // TODO(b/358275922): remove Attrs (and the unused Dictionary above)
        // 12 weeks after release of jaxlib v0.4.32.
        .Attrs()
        .Arg<ffi::Buffer<ffi::DataType::S32>>()    // pivots
        .Ret<ffi::Buffer<ffi::DataType::S32>>());  // permutation
|
2022-02-15 17:54:02 +00:00
|
|
|
|
2022-10-25 07:23:07 -07:00
|
|
|
} // namespace JAX_GPU_NAMESPACE
|
2022-05-06 13:47:23 -07:00
|
|
|
} // namespace jax
|