Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Migrate 'jaxlib' CPU custom-calls to the status-returning API

PiperOrigin-RevId: 438165260

parent b31cf89e48
commit 8884ce5b98
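
Background: under the original CPU custom-call API, XLA invokes a target as `void(void* out, void** data)`, so a kernel that hits a runtime problem has no channel for reporting it. The status-returning API adds a trailing `XlaCustomCallStatus*` parameter and requires the builder to declare `api_version=API_VERSION_STATUS_RETURNING`. This commit only threads the new parameter through; every migrated kernel ignores it for now. Below is a minimal sketch of what the extra argument enables. `XlaCustomCallStatusSetFailure` is the real helper from `custom_call_status.h`, but `MyKernel` and its payload layout are hypothetical, not code from this commit.

```cpp
#include <cstdint>
#include <cstring>

#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical status-returning CPU custom-call kernel. XLA passes the
// trailing status pointer only when the custom call was built with
// api_version=API_VERSION_STATUS_RETURNING.
void MyKernel(void* out, void** data, XlaCustomCallStatus* status) {
  const std::int32_t n = *reinterpret_cast<std::int32_t*>(data[0]);
  if (n < 0) {
    // Report the error instead of crashing the process; XLA propagates
    // the message to the caller as a failed execution.
    static constexpr char kMsg[] = "MyKernel: expected a non-negative size";
    XlaCustomCallStatusSetFailure(status, kMsg, std::strlen(kMsg));
    return;
  }
  // ... perform the computation and write the result through `out` ...
  *reinterpret_cast<std::int32_t*>(out) = n;
}
```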
@@ -145,6 +145,7 @@ cc_library(
     srcs = ["lapack_kernels.cc"],
     hdrs = ["lapack_kernels.h"],
     deps = [
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@com_google_absl//absl/base:dynamic_annotations",
     ],
 )
@@ -198,6 +199,7 @@ cc_library(
     features = ["-use_header_modules"],
     deps = [
         ":pocketfft_flatbuffers_cc",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@flatbuffers//:runtime_cc",
         "@pocketfft",
     ],
@@ -94,7 +94,9 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
           Shape.array_shape(dtype, (), ()),
           Shape.array_shape(dtype, a_shape.dimensions(), layout),
           Shape.array_shape(dtype, b_shape.dimensions(), layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
 jax_trsm = trsm
 
 # ?getrf: LU decomposition
@@ -149,7 +151,9 @@ def getrf(c, a):
               dtype,
               batch_dims + (m, n),
               (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(3))
 
 # ?geqrf: QR decomposition
@@ -212,7 +216,9 @@ def geqrf(c, a):
               dtype,
               batch_dims + (m, n),
               (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(3))
 
 # ?orgqr: product of elementary Householder reflectors:
@@ -282,7 +288,9 @@ def orgqr(c, a, tau):
               dtype,
               batch_dims + (k,),
               tuple(range(num_bd, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(2))
 
 
@@ -326,7 +334,9 @@ def potrf(c, a, lower=False):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(2))
 
 
@@ -420,7 +430,9 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2),
           _ops.GetTupleElement(out, 3), _ops.GetTupleElement(out, 4))
 
@@ -491,7 +503,9 @@ def syevd(c, a, lower=False):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -575,7 +589,9 @@ def geev(c, a, jobvl=True, jobvr=True):
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   if real:
     return (_ops.Complex(_ops.GetTupleElement(out, 3),
                          _ops.GetTupleElement(out, 4)),
@@ -653,7 +669,9 @@ def gees(c, a, jobvs=True, sort=False, select=None):
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   if sort == ord('S'):
     return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 3),
             _ops.GetTupleElement(out, 4), _ops.GetTupleElement(out, 5))
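
Every builder in the hunks above changes the same way: the trailing `))` that closed `CustomCallWithLayout` becomes `),` followed by the new keyword argument. A condensed sketch of the resulting call shape, with a hypothetical target name and placeholder shapes (only the `api_version` lines are the point of the change):

```python
from jaxlib import xla_client

def my_custom_call(c, operand):
  shape = c.get_shape(operand)
  return xla_client.ops.CustomCallWithLayout(
      c,
      b"my_target",  # hypothetical registered custom-call target
      operands=(operand,),
      shape_with_layout=shape,
      operand_shapes_with_layout=(shape,),
      # Opt in to the status-returning calling convention; without this,
      # XLA still invokes the target with the old (out, data) signature.
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING)
```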
@@ -30,7 +30,7 @@ template <typename T>
 typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;
 
 template <typename T>
-void Trsm<T>::Kernel(void* out, void** data) {
+void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
   int32_t left_side = *reinterpret_cast<int32_t*>(data[0]);
   int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
   int32_t trans_a = *reinterpret_cast<int32_t*>(data[2]);
@@ -82,7 +82,7 @@ template <typename T>
 typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
 
 template <typename T>
-void Getrf<T>::Kernel(void* out_tuple, void** data) {
+void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -116,7 +116,7 @@ template <typename T>
 typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;
 
 template <typename T>
-void Geqrf<T>::Kernel(void* out_tuple, void** data) {
+void Geqrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -163,7 +163,7 @@ template <typename T>
 typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;
 
 template <typename T>
-void Orgqr<T>::Kernel(void* out_tuple, void** data) {
+void Orgqr<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -211,7 +211,7 @@ template <typename T>
 typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;
 
 template <typename T>
-void Potrf<T>::Kernel(void* out_tuple, void** data) {
+void Potrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -260,7 +260,7 @@ template <typename T>
 typename RealGesdd<T>::FnType* RealGesdd<T>::fn = nullptr;
 
 template <typename T>
-void RealGesdd<T>::Kernel(void* out_tuple, void** data) {
+void RealGesdd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
   int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
   int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -332,7 +332,8 @@ template <typename T>
 typename ComplexGesdd<T>::FnType* ComplexGesdd<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGesdd<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGesdd<T>::Kernel(void* out_tuple, void** data,
+                             XlaCustomCallStatus*) {
   int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
   int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
   int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -411,7 +412,7 @@ template <typename T>
 typename RealSyevd<T>::FnType* RealSyevd<T>::fn = nullptr;
 
 template <typename T>
-void RealSyevd<T>::Kernel(void* out_tuple, void** data) {
+void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -459,7 +460,8 @@ template <typename T>
 typename ComplexHeevd<T>::FnType* ComplexHeevd<T>::fn = nullptr;
 
 template <typename T>
-void ComplexHeevd<T>::Kernel(void* out_tuple, void** data) {
+void ComplexHeevd<T>::Kernel(void* out_tuple, void** data,
+                             XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -531,7 +533,7 @@ template <typename T>
 typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
 
 template <typename T>
-void RealGeev<T>::Kernel(void* out_tuple, void** data) {
+void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -590,7 +592,8 @@ template <typename T>
 typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
+                            XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -648,7 +651,7 @@ template <typename T>
 typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
 
 template <typename T>
-void RealGees<T>::Kernel(void* out_tuple, void** data) {
+void RealGees<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -708,7 +711,8 @@ template <typename T>
 typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGees<T>::Kernel(void* out_tuple, void** data,
+                            XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <complex>
 #include <cstdint>
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
 
 // Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
 // by the pybind wrapper that links them to an existing SciPy lapack instance,
@@ -35,7 +36,7 @@ struct Trsm {
                      lapack_int* lda, T* b, lapack_int* ldb);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -44,7 +45,7 @@ struct Getrf {
                      lapack_int* ipiv, lapack_int* info);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -53,7 +54,7 @@ struct Geqrf {
                      T* tau, T* work, lapack_int* lwork, lapack_int* info);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 
   static int64_t Workspace(lapack_int m, lapack_int n);
 };
@@ -64,7 +65,7 @@ struct Orgqr {
                      lapack_int* lda, T* tau, T* work, lapack_int* lwork,
                      lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
   static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
 };
 
@@ -73,7 +74,7 @@ struct Potrf {
   using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 lapack_int GesddIworkSize(int64_t m, int64_t n);
@@ -85,7 +86,7 @@ struct RealGesdd {
                      lapack_int* ldvt, T* work, lapack_int* lwork,
                      lapack_int* iwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 
   static int64_t Workspace(lapack_int m, lapack_int n,
                            bool job_opt_compute_uv, bool job_opt_full_matrices);
@@ -101,7 +102,7 @@ struct ComplexGesdd {
                      lapack_int* lwork, typename T::value_type* rwork,
                      lapack_int* iwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 
   static int64_t Workspace(lapack_int m, lapack_int n,
                            bool job_opt_compute_uv, bool job_opt_full_matrices);
@@ -117,7 +118,7 @@ struct RealSyevd {
                      lapack_int* lda, T* w, T* work, lapack_int* lwork,
                      lapack_int* iwork, lapack_int* liwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 lapack_int HeevdWorkSize(int64_t n);
@@ -131,7 +132,7 @@ struct ComplexHeevd {
                      lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork,
                      lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -141,7 +142,7 @@ struct RealGeev {
                      T* vr, lapack_int* ldvr, T* work, lapack_int* lwork,
                      lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -151,7 +152,7 @@ struct ComplexGeev {
                      lapack_int* ldvr, T* work, lapack_int* lwork,
                      typename T::value_type* rwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -161,7 +162,7 @@ struct RealGees {
                      T* wr, T* wi, T* vs, lapack_int* ldvs, T* work,
                      lapack_int* lwork, bool* bwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -172,7 +173,7 @@ struct ComplexGees {
                      typename T::value_type* rwork, bool* bwork,
                      lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 } // namespace jax
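
The header comment above describes the pattern these declarations follow: each LAPACK wrapper is a struct holding a static `fn` pointer to the underlying routine (filled in at import time by the pybind wrapper, either from SciPy's bundled LAPACK or from a locally built one) plus a static `Kernel` that XLA invokes. A schematic of that pattern with the new signature, using a made-up routine for illustration (`Gecon` and this `lapack_int` alias are not from jaxlib):

```cpp
#include <cstdint>

#include "tensorflow/compiler/xla/service/custom_call_status.h"

using lapack_int = int;  // assumption: stands in for jaxlib's lapack_int

template <typename T>
struct Gecon {  // hypothetical wrapper, same shape as Trsm/Getrf/...
  // Signature of the underlying LAPACK routine (here ?gecon).
  using FnType = void(char* norm, lapack_int* n, T* a, lapack_int* lda,
                      T* anorm, T* rcond, T* work, lapack_int* iwork,
                      lapack_int* info);

  // Resolved at startup by the pybind wrapper, not linked statically.
  static FnType* fn;

  // Entry point registered with XLA; the status argument is the new API.
  static void Kernel(void* out_tuple, void** data, XlaCustomCallStatus*);
};
```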
@@ -20,7 +20,6 @@ from . import _pocketfft
 from . import pocketfft_flatbuffers_py_generated as pd
 import numpy as np
 
-
 import flatbuffers
 from jaxlib import xla_client
 
@@ -53,8 +52,9 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
       pd.PocketFftDtype.COMPLEX64
       if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
 
-  assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
-      shape, fft_lengths)
+  assert list(
+      shape.dimensions())[-len(fft_lengths):] == fft_lengths, (shape,
+                                                               fft_lengths)
   out_shape = list(shape.dimensions())
   out_shape[-1] = out_shape[-1] // 2 + 1
 
@@ -80,8 +80,9 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
       pd.PocketFftDtype.COMPLEX64
       if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
 
-  assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
-      shape, fft_lengths)
+  assert list(
+      shape.dimensions())[-len(fft_lengths):] == fft_lengths, (shape,
+                                                               fft_lengths)
   out_shape = shape.dimensions()
 
   # PocketFft does not allow size 0 dimensions.
@@ -156,4 +157,6 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
           np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
       xla_client.Shape.array_shape(dtype, shape.dimensions(),
                                    tuple(range(n - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
@@ -18,10 +18,11 @@ limitations under the License.
 #include "flatbuffers/flatbuffers.h"
 #include "pocketfft/pocketfft_hdronly.h"
 #include "jaxlib/pocketfft_generated.h"
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 
-void PocketFft(void* out, void** in) {
+void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
   const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
   pocketfft::shape_t shape(descriptor->shape()->begin(),
                            descriptor->shape()->end());
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
+
 namespace jax {
 
-void PocketFft(void* out, void** in);
+void PocketFft(void* out, void** in, XlaCustomCallStatus*);
 
 } // namespace jax