mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.) This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release). Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released. PiperOrigin-RevId: 730928808
This commit is contained in:
parent
05614edc7d
commit
2ce88c950a
@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import svd as lax_svd
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lax.lax import _float, _complex, _int
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -2390,12 +2391,17 @@ def _triangular_solve_cpu_lower(
|
||||
conjugate_a = False
|
||||
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
|
||||
target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype)
|
||||
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
|
||||
alpha_aval = ShapedArray((), a_aval.dtype)
|
||||
# TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18.
|
||||
if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1):
|
||||
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)),
|
||||
alpha_aval = ShapedArray((), a_aval.dtype),
|
||||
else:
|
||||
alpha = ()
|
||||
alpha_aval = ()
|
||||
rule = _linalg_ffi_lowering(target_name,
|
||||
[a_aval, b_aval, alpha_aval],
|
||||
[a_aval, b_aval, *alpha_aval],
|
||||
operand_output_aliases={1: 0})
|
||||
return rule(ctx, a, b, alpha,
|
||||
return rule(ctx, a, b, *alpha,
|
||||
side=_matrix_side_attr(left_side),
|
||||
uplo=_matrix_uplo_attr(lower),
|
||||
trans_x=_matrix_transpose_attr(transpose_a, conjugate_a),
|
||||
|
@ -146,7 +146,10 @@ template struct Trsm<std::complex<double>>;
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y, ffi::BufferR0<dtype> alpha,
|
||||
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y,
|
||||
// TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after
|
||||
// the release of JAX 0.5.2.
|
||||
ffi::RemainingArgs,
|
||||
ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
|
||||
MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
|
||||
MatrixParams::Diag diag) {
|
||||
@ -168,10 +171,10 @@ ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
|
||||
auto* x_data = x.typed_data();
|
||||
const int64_t y_out_step{y_rows * y_cols};
|
||||
const int64_t x_step{x_leading_dim_v * x_leading_dim_v};
|
||||
ffi::NativeType<dtype> alpha = static_cast<ffi::NativeType<dtype>>(1);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v,
|
||||
alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data,
|
||||
&y_leading_dim_v);
|
||||
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, &alpha,
|
||||
x_data, &x_leading_dim_v, y_out_data, &y_leading_dim_v);
|
||||
|
||||
y_out_data += y_out_step;
|
||||
x_data += x_step;
|
||||
@ -2241,17 +2244,17 @@ template struct TridiagonalSolver<ffi::DataType::C128>;
|
||||
|
||||
// FFI Definition Macros (by DataType)
|
||||
|
||||
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, TriMatrixEquationSolver<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
|
||||
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
|
||||
.Attr<MatrixParams::Side>("side") \
|
||||
.Attr<MatrixParams::UpLo>("uplo") \
|
||||
.Attr<MatrixParams::Transpose>("trans_x") \
|
||||
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, TriMatrixEquationSolver<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
|
||||
.RemainingArgs() \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
|
||||
.Attr<MatrixParams::Side>("side") \
|
||||
.Attr<MatrixParams::UpLo>("uplo") \
|
||||
.Attr<MatrixParams::Transpose>("trans_x") \
|
||||
.Attr<MatrixParams::Diag>("diag"))
|
||||
|
||||
#define JAX_CPU_DEFINE_GETRF(name, data_type) \
|
||||
|
@ -147,7 +147,7 @@ struct TriMatrixEquationSolver {
|
||||
inline static FnType* fn = nullptr;
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::Buffer<dtype> y,
|
||||
::xla::ffi::BufferR0<dtype> alpha, ::xla::ffi::ResultBuffer<dtype> y_out,
|
||||
::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer<dtype> y_out,
|
||||
MatrixParams::Side side, MatrixParams::UpLo uplo,
|
||||
MatrixParams::Transpose trans_x, MatrixParams::Diag diag);
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user