diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ad7082504..0336eff9c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2390,12 +2391,17 @@ def _triangular_solve_cpu_lower( conjugate_a = False if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types: target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype) - alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)) - alpha_aval = ShapedArray((), a_aval.dtype) + # TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18. + if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1): + alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)), + alpha_aval = ShapedArray((), a_aval.dtype), + else: + alpha = () + alpha_aval = () rule = _linalg_ffi_lowering(target_name, - [a_aval, b_aval, alpha_aval], + [a_aval, b_aval, *alpha_aval], operand_output_aliases={1: 0}) - return rule(ctx, a, b, alpha, + return rule(ctx, a, b, *alpha, side=_matrix_side_attr(left_side), uplo=_matrix_uplo_attr(lower), trans_x=_matrix_transpose_attr(transpose_a, conjugate_a), diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 894ab13bb..ddc93261e 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -146,7 +146,10 @@ template struct Trsm>; template ffi::Error TriMatrixEquationSolver::Kernel( - ffi::Buffer x, ffi::Buffer y, ffi::BufferR0 alpha, + ffi::Buffer x, ffi::Buffer y, + // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after + // the release of JAX 0.5.2. + ffi::RemainingArgs, ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { @@ -168,10 +171,10 @@ ffi::Error TriMatrixEquationSolver::Kernel( auto* x_data = x.typed_data(); const int64_t y_out_step{y_rows * y_cols}; const int64_t x_step{x_leading_dim_v * x_leading_dim_v}; + ffi::NativeType alpha = static_cast>(1); for (int64_t i = 0; i < batch_count; ++i) { - fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, - alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data, - &y_leading_dim_v); + fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, &alpha, + x_data, &x_leading_dim_v, y_out_data, &y_leading_dim_v); y_out_data += y_out_step; x_data += x_step; @@ -2241,17 +2244,17 @@ template struct TridiagonalSolver; // FFI Definition Macros (by DataType) -#define JAX_CPU_DEFINE_TRSM(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, TriMatrixEquationSolver::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*y*/) \ - .Arg<::xla::ffi::BufferR0>(/*alpha*/) \ - .Ret<::xla::ffi::Buffer>(/*y_out*/) \ - .Attr("side") \ - .Attr("uplo") \ - .Attr("trans_x") \ +#define JAX_CPU_DEFINE_TRSM(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, TriMatrixEquationSolver::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*y*/) \ + .RemainingArgs() \ + .Ret<::xla::ffi::Buffer>(/*y_out*/) \ + .Attr("side") \ + .Attr("uplo") \ + .Attr("trans_x") \ .Attr("diag")) #define JAX_CPU_DEFINE_GETRF(name, data_type) \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index d94b5af61..e075ff293 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -147,7 +147,7 @@ struct TriMatrixEquationSolver { inline static FnType* fn = nullptr; static ::xla::ffi::Error Kernel( ::xla::ffi::Buffer x, ::xla::ffi::Buffer y, - ::xla::ffi::BufferR0 alpha, ::xla::ffi::ResultBuffer y_out, + ::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag); };