Deprecate alpha argument to trsm LAPACK kernel.

(Part of general cleanups of the lax.linalg submodule.)

The alpha argument is always set to 1 and I don't see any benefit to keeping it around. Removing it can be done in a forward- and backward-compatible way by following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, while still allowing it to accept (and ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old, i.e. <= 0.5.1 (I'm requiring jaxlib > 0.5.1 before dropping alpha, assuming that this change doesn't make it into the upcoming release).

After that, the forward compatibility lowering can be removed once at least 21 days have passed (i.e. no earlier than 2025-03-18), and the kernel can be updated to drop the RemainingArgs placeholder no earlier than 180 days after 0.5.2 is released.

PiperOrigin-RevId: 730928808
This commit is contained in:
Dan Foreman-Mackey 2025-02-25 10:03:48 -08:00 committed by jax authors
parent 05614edc7d
commit 2ce88c950a
3 changed files with 29 additions and 20 deletions

View File

@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lax import utils as lax_utils
from jax._src.lax.lax import _float, _complex, _int
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -2390,12 +2391,17 @@ def _triangular_solve_cpu_lower(
conjugate_a = False
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype)
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
alpha_aval = ShapedArray((), a_aval.dtype)
# TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18.
if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1):
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)),
alpha_aval = ShapedArray((), a_aval.dtype),
else:
alpha = ()
alpha_aval = ()
rule = _linalg_ffi_lowering(target_name,
[a_aval, b_aval, alpha_aval],
[a_aval, b_aval, *alpha_aval],
operand_output_aliases={1: 0})
return rule(ctx, a, b, alpha,
return rule(ctx, a, b, *alpha,
side=_matrix_side_attr(left_side),
uplo=_matrix_uplo_attr(lower),
trans_x=_matrix_transpose_attr(transpose_a, conjugate_a),

View File

@ -146,7 +146,10 @@ template struct Trsm<std::complex<double>>;
template <ffi::DataType dtype>
ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y, ffi::BufferR0<dtype> alpha,
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y,
// TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after
// the release of JAX 0.5.2.
ffi::RemainingArgs,
ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
MatrixParams::Diag diag) {
@ -168,10 +171,10 @@ ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
auto* x_data = x.typed_data();
const int64_t y_out_step{y_rows * y_cols};
const int64_t x_step{x_leading_dim_v * x_leading_dim_v};
ffi::NativeType<dtype> alpha = static_cast<ffi::NativeType<dtype>>(1);
for (int64_t i = 0; i < batch_count; ++i) {
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v,
alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data,
&y_leading_dim_v);
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, &alpha,
x_data, &x_leading_dim_v, y_out_data, &y_leading_dim_v);
y_out_data += y_out_step;
x_data += x_step;
@ -2241,17 +2244,17 @@ template struct TridiagonalSolver<ffi::DataType::C128>;
// FFI Definition Macros (by DataType)
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TriMatrixEquationSolver<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
.Attr<MatrixParams::Side>("side") \
.Attr<MatrixParams::UpLo>("uplo") \
.Attr<MatrixParams::Transpose>("trans_x") \
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TriMatrixEquationSolver<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
.RemainingArgs() \
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
.Attr<MatrixParams::Side>("side") \
.Attr<MatrixParams::UpLo>("uplo") \
.Attr<MatrixParams::Transpose>("trans_x") \
.Attr<MatrixParams::Diag>("diag"))
#define JAX_CPU_DEFINE_GETRF(name, data_type) \

View File

@ -147,7 +147,7 @@ struct TriMatrixEquationSolver {
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::Buffer<dtype> y,
::xla::ffi::BufferR0<dtype> alpha, ::xla::ffi::ResultBuffer<dtype> y_out,
::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer<dtype> y_out,
MatrixParams::Side side, MatrixParams::UpLo uplo,
MatrixParams::Transpose trans_x, MatrixParams::Diag diag);
};