Activate Triangular Solve to XLA's FFI

PiperOrigin-RevId: 705029286
This commit is contained in:
Paweł Paruzel 2024-12-11 02:21:56 -08:00 committed by jax authors
parent 3d9c720d42
commit 1256153200
9 changed files with 354 additions and 49 deletions

View File

@ -1017,6 +1017,7 @@ _CPU_FFI_KERNELS = [
"lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
"lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
"lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
"lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.

View File

@ -203,3 +203,281 @@ module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00",
xla_call_module_version=6,
) # End paste
data_2024_12_02 = {}
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['c128'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_ztrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array([
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
]),
array([
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
[10.0 + 0.0j, 11.0 + 0.0j, 12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j],
[15.0 + 0.0j, 16.0 + 0.0j, 17.0 + 0.0j, 18.0 + 0.0j, 19.0 + 0.0j],
]),
),
expected_outputs=(
array([
[
0.0 + 0.0j,
0.2 + 0.0j,
0.4 + 0.0j,
0.6000000000000001 + 0.0j,
0.8 + 0.0j,
],
[
0.5 + 0.0j,
0.52 + 0.0j,
0.54 + 0.0j,
0.5599999999999999 + 0.0j,
0.58 + 0.0j,
],
[
0.36666666666666664 + 0.0j,
0.3146666666666667 + 0.0j,
0.2626666666666667 + 0.0j,
0.21066666666666667 + 0.0j,
0.15866666666666665 + 0.0j,
],
[
0.16833333333333336 + 0.0j,
0.1217333333333333 + 0.0j,
0.07513333333333323 + 0.0j,
0.0285333333333333 + 0.0j,
-0.018066666666666675 + 0.0j,
],
]),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f64>> loc("a"), %arg1: tensor<4x5xcomplex<f64>> loc("b")) -> (tensor<4x5xcomplex<f64>> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc4)
%0 = stablehlo.custom_call @lapack_ztrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f64>>, tensor<4x5xcomplex<f64>>, tensor<complex<f64>>) -> tensor<4x5xcomplex<f64>> loc(#loc4)
return %0 : tensor<4x5xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0bO+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\x1e\x03\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\x0b)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ztrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['c64'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_ctrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array(
[
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
],
dtype=complex64,
),
array(
[
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
[
10.0 + 0.0j,
11.0 + 0.0j,
12.0 + 0.0j,
13.0 + 0.0j,
14.0 + 0.0j,
],
[
15.0 + 0.0j,
16.0 + 0.0j,
17.0 + 0.0j,
18.0 + 0.0j,
19.0 + 0.0j,
],
],
dtype=complex64,
),
),
expected_outputs=(
array(
[
[0.0 + 0.0j, 0.2 + 0.0j, 0.4 + 0.0j, 0.6 + 0.0j, 0.8 + 0.0j],
[
0.5 + 0.0j,
0.52 + 0.0j,
0.54 + 0.0j,
0.56 + 0.0j,
0.58000004 + 0.0j,
],
[
0.36666667 + 0.0j,
0.31466666 + 0.0j,
0.26266667 + 0.0j,
0.21066667 + 0.0j,
0.15866666 + 0.0j,
],
[
0.16833334 + 0.0j,
0.12173338 + 0.0j,
0.0751333 + 0.0j,
0.02853328 + 0.0j,
-0.018066704 + 0.0j,
],
],
dtype=complex64,
),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f32>> loc("a"), %arg1: tensor<4x5xcomplex<f32>> loc("b")) -> (tensor<4x5xcomplex<f32>> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc4)
%0 = stablehlo.custom_call @lapack_ctrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f32>>, tensor<4x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<4x5xcomplex<f32>> loc(#loc4)
return %0 : tensor<4x5xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\xfe\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\t)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ctrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['f32'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_strsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array(
[
[5.0, 0.0, 0.0, 0.0],
[4.0, 10.0, 0.0, 0.0],
[8.0, 9.0, 15.0, 0.0],
[12.0, 13.0, 14.0, 20.0],
],
dtype=float32,
),
array(
[
[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
],
dtype=float32,
),
),
expected_outputs=(
array(
[
[0.0, 0.2, 0.4, 0.6, 0.8],
[0.5, 0.52, 0.54, 0.56, 0.58000004],
[0.36666667, 0.31466666, 0.26266667, 0.21066667, 0.15866666],
[0.16833334, 0.12173338, 0.0751333, 0.02853328, -0.018066704],
],
dtype=float32,
),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf32> loc("a"), %arg1: tensor<4x5xf32> loc("b")) -> (tensor<4x5xf32> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc4)
%0 = stablehlo.custom_call @lapack_strsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf32>, tensor<4x5xf32>, tensor<f32>) -> tensor<4x5xf32> loc(#loc4)
return %0 : tensor<4x5xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x1f+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xe6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\t\x00\x00\x80?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_strsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['f64'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_dtrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array([
[5.0, 0.0, 0.0, 0.0],
[4.0, 10.0, 0.0, 0.0],
[8.0, 9.0, 15.0, 0.0],
[12.0, 13.0, 14.0, 20.0],
]),
array([
[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
]),
),
expected_outputs=(
array([
[0.0, 0.2, 0.4, 0.6000000000000001, 0.8],
[0.5, 0.52, 0.54, 0.5599999999999999, 0.58],
[
0.36666666666666664,
0.3146666666666667,
0.2626666666666667,
0.21066666666666667,
0.15866666666666665,
],
[
0.16833333333333336,
0.1217333333333333,
0.07513333333333323,
0.0285333333333333,
-0.018066666666666675,
],
]),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf64> loc("a"), %arg1: tensor<4x5xf64> loc("b")) -> (tensor<4x5xf64> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64> loc(#loc4)
%0 = stablehlo.custom_call @lapack_dtrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf64>, tensor<4x5xf64>, tensor<f64>) -> tensor<4x5xf64> loc(#loc4)
return %0 : tensor<4x5xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xf6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_dtrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste

View File

@ -1329,7 +1329,6 @@ def _triangular_solve_lowering(
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
def _triangular_solve_cpu_lower(
ctx, a, b, *, left_side, lower, transpose_a,
@ -1342,8 +1341,10 @@ def _triangular_solve_cpu_lower(
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
return lapack.trsm_hlo(
a_aval.dtype, alpha,
*ctx_args, a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)
else:
@ -1358,6 +1359,8 @@ def _triangular_solve_cpu_lower(
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')

View File

@ -117,10 +117,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
// FFI Kernels
JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ztrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);

View File

@ -234,10 +234,10 @@ nb::dict Registrations() {
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi);
dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi);
dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi);
dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi);
dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi);
dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi);
dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi);
dict["lapack_ztrsm_ffi"] = EncapsulateFunction(lapack_ztrsm_ffi);
dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);

View File

@ -2128,10 +2128,10 @@ template struct TridiagonalReduction<ffi::DataType::C128>;
// FFI Handlers
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_TRSM(lapack_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(lapack_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(lapack_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(lapack_ztrsm_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);

View File

@ -741,10 +741,10 @@ struct TridiagonalReduction {
};
// Declare all the handler symbols
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ctrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ztrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_strsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dtrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ctrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ztrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgetrf_ffi);

View File

@ -118,37 +118,32 @@ def build_lapack_fn_target(fn_base: str, dtype) -> str:
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_hlo(dtype, alpha, a, b,
def trsm_hlo(ctx, dtype, alpha, a, b,
left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False, *,
b_shape_vals: tuple[DimensionSize, ...]):
_lapack.initialize()
b_type = ir.RankedTensorType(b.type)
m, n = b_shape_vals[-2:]
batch_dims_vals = b_shape_vals[:-2]
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
if dtype == np.float32:
fn = "blas_strsm"
elif dtype == np.float64:
fn = "blas_dtrsm"
elif dtype == np.complex64:
fn = "blas_ctrsm"
elif dtype == np.complex128:
fn = "blas_ztrsm"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
if conj_a and not trans_a:
raise NotImplementedError("Conjugation without transposition not supported")
fn_base = prepare_lapack_call(fn_base="trsm", dtype=dtype)
b_type = ir.RankedTensorType(b.type)
batch_dims_vals = b_shape_vals[:-2]
num_bd = len(batch_dims_vals)
scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
result_types, result_shapes = mk_result_types_and_shapes(
[(b_shape_vals, b_type.element_type)])
if ctx.is_forward_compat():
# The old TRSM kernel name is prefixed with "blas"
fn = fn_base.replace("lapack", "blas", 1)
m, n = b_shape_vals[-2:]
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
result_types, result_shapes = mk_result_types_and_shapes(
[(b_shape_vals, b_type.element_type)]
)
return custom_call(
fn,
result_types=result_types,
@ -162,6 +157,25 @@ def trsm_hlo(dtype, alpha, a, b,
result_shapes=result_shapes,
).results
fn = fn_base + "_ffi"
return custom_call(
fn,
result_types=result_types,
operands=[a, b, alpha],
operand_layouts=[layout] * 2 + [scalar_layout],
result_layouts=[layout],
operand_output_aliases={1: 0},
result_shapes=result_shapes,
backend_config={
"side": _matrix_side_attr(left_side=left_side),
"uplo": _matrix_uplo_attr(lower=lower),
"trans_x": _matrix_transpose_attr(
transpose=trans_a, conjugate=conj_a
),
"diag": _matrix_diagonal_attr(unit_diag=diag),
},
api_version=4,
).results
# ?potrf: Cholesky decomposition

View File

@ -122,6 +122,7 @@ class CompatTest(bctu.CompatTestBase):
cpu_eigh_lapack_syev.data_2024_08_19,
cpu_lu_lapack_getrf.data_2024_05_31,
cpu_schur_lapack_gees.data_2024_11_29,
cpu_triangular_solve_blas_trsm.data_2024_12_02,
cpu_svd_lapack_gesdd.data_2024_08_13,
cpu_hessenberg_lapack_gehrd.data_2024_08_31,
cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01,
@ -741,6 +742,14 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_triangular_solve_results)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_triangular_solve_results)
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)