From 2ce88c950abcafa4bbd934b5bee61cc61011f8e3 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 25 Feb 2025 10:03:48 -0800 Subject: [PATCH] Deprecate alpha argument to trsm LAPACK kernel. (Part of general cleanups of the lax.linalg submodule.) This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward- and backward-compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old, i.e. 0.5.1 or earlier (only versions >0.5.1 get the new calling convention; I'm using this cutoff assuming that this change doesn't make it into the upcoming release). Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released. 
PiperOrigin-RevId: 730928808 --- jax/_src/lax/linalg.py | 14 ++++++++++---- jaxlib/cpu/lapack_kernels.cc | 33 ++++++++++++++++++--------------- jaxlib/cpu/lapack_kernels.h | 2 +- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ad7082504..0336eff9c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2390,12 +2391,17 @@ def _triangular_solve_cpu_lower( conjugate_a = False if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types: target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype) - alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)) - alpha_aval = ShapedArray((), a_aval.dtype) + # TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18. 
+ if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1): + alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)), + alpha_aval = ShapedArray((), a_aval.dtype), + else: + alpha = () + alpha_aval = () rule = _linalg_ffi_lowering(target_name, - [a_aval, b_aval, alpha_aval], + [a_aval, b_aval, *alpha_aval], operand_output_aliases={1: 0}) - return rule(ctx, a, b, alpha, + return rule(ctx, a, b, *alpha, side=_matrix_side_attr(left_side), uplo=_matrix_uplo_attr(lower), trans_x=_matrix_transpose_attr(transpose_a, conjugate_a), diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 894ab13bb..ddc93261e 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -146,7 +146,10 @@ template struct Trsm>; template ffi::Error TriMatrixEquationSolver::Kernel( - ffi::Buffer x, ffi::Buffer y, ffi::BufferR0 alpha, + ffi::Buffer x, ffi::Buffer y, + // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after + // the release of JAX 0.5.2. + ffi::RemainingArgs, ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { @@ -168,10 +171,10 @@ ffi::Error TriMatrixEquationSolver::Kernel( auto* x_data = x.typed_data(); const int64_t y_out_step{y_rows * y_cols}; const int64_t x_step{x_leading_dim_v * x_leading_dim_v}; + ffi::NativeType alpha = static_cast>(1); for (int64_t i = 0; i < batch_count; ++i) { - fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, - alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data, - &y_leading_dim_v); + fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, &alpha, + x_data, &x_leading_dim_v, y_out_data, &y_leading_dim_v); y_out_data += y_out_step; x_data += x_step; @@ -2241,17 +2244,17 @@ template struct TridiagonalSolver; // FFI Definition Macros (by DataType) -#define JAX_CPU_DEFINE_TRSM(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, TriMatrixEquationSolver::Kernel, \ - 
::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*y*/) \ - .Arg<::xla::ffi::BufferR0>(/*alpha*/) \ - .Ret<::xla::ffi::Buffer>(/*y_out*/) \ - .Attr("side") \ - .Attr("uplo") \ - .Attr("trans_x") \ +#define JAX_CPU_DEFINE_TRSM(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, TriMatrixEquationSolver::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*y*/) \ + .RemainingArgs() \ + .Ret<::xla::ffi::Buffer>(/*y_out*/) \ + .Attr("side") \ + .Attr("uplo") \ + .Attr("trans_x") \ .Attr("diag")) #define JAX_CPU_DEFINE_GETRF(name, data_type) \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index d94b5af61..e075ff293 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -147,7 +147,7 @@ struct TriMatrixEquationSolver { inline static FnType* fn = nullptr; static ::xla::ffi::Error Kernel( ::xla::ffi::Buffer x, ::xla::ffi::Buffer y, - ::xla::ffi::BufferR0 alpha, ::xla::ffi::ResultBuffer y_out, + ::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag); };