Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

commit 39ce7916f1 (parent c1de7c733d)

Activate FFI implementation of tridiagonal reduction on GPU.

PiperOrigin-RevId: 714078036
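
As a usage sketch (not part of the diff; the shape and values below are illustrative), this is the user-facing path that now reaches the new FFI kernel: on CUDA and ROCm backends, jax.lax.linalg.tridiagonal lowers to the cusolver_sytrd_ffi / hipsolver_sytrd_ffi custom-call targets.

    import jax
    import jax.numpy as jnp
    from jax import lax

    # A small batch of symmetric matrices (illustrative shape).
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (2, 4, 4), dtype=jnp.float32)
    a = (a + jnp.swapaxes(a, -1, -2)) / 2

    # Reduce each matrix to tridiagonal form; on GPU this now goes through the FFI custom call.
    arr, d, e, taus = jax.jit(lambda m: lax.linalg.tridiagonal(m, lower=True))(a)
    print(arr.shape, d.shape, e.shape, taus.shape)  # (2, 4, 4) (2, 4) (2, 3) (2, 3)
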
@@ -1021,10 +1021,26 @@ _CPU_FFI_KERNELS = [
    "lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
    "lapack_sgtsv_ffi", "lapack_dgtsv_ffi", "lapack_cgtsv_ffi", "lapack_zgtsv_ffi",
]
_GPU_FFI_KERNELS = [
    # lu on GPU
    "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
    "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
    # qr on GPU
    "cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
    "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
    # eigh on GPU
    "cusolver_syevd_ffi", "hipsolver_syevd_ffi",
    # svd on GPU
    "cusolver_gesvd_ffi", "cusolver_gesvdj_ffi",
    "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
    # tridiagonal on GPU
    "cusolver_sytrd_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
    *_CPU_FFI_KERNELS,
    *_GPU_FFI_KERNELS,
    "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
    "cu_threefry2x32", "cu_threefry2x32_ffi",
    # Triton IR does not guarantee stability.
@@ -1047,18 +1063,6 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
    "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd",
    # hessenberg on CPU
    "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd",
    # lu on GPU
    "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
    "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
    "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
    # qr on GPU
    "cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
    "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
    # eigh on GPU
    "cusolver_syevd_ffi", "hipsolver_syevd_ffi",
    # svd on GPU
    "cusolver_gesvd_ffi", "cusolver_gesvdj_ffi",
    "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
    # lu on TPU
    "LuDecomposition",
    # ApproxTopK on TPU
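
A quick way to see the effect of the new stable target (a sketch, not part of the diff; it assumes jax with CUDA support and a visible GPU): lower a jitted tridiagonal reduction and check that the emitted StableHLO references the cusolver_sytrd_ffi custom call that is now listed as guaranteed stable above.

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.zeros((2, 4, 4), jnp.float32)
    lowered = jax.jit(lambda m: lax.linalg.tridiagonal(m, lower=True)).lower(x)
    # On a CUDA backend the custom call should now name the FFI kernel.
    assert "cusolver_sytrd_ffi" in lowered.as_text()
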
@@ -0,0 +1,388 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa

import datetime
from numpy import array, float32, complex64

data_2025_01_09 = {}

data_2025_01_09["f32"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_sytrd_ffi'],
serialized_date=datetime.date(2025, 1, 9),
inputs=(array([[[-1.2178137 , -1.4850428 , 2.5807054 , -2.7725415 ],
|
||||
[ 0.7619374 , -4.455563 , -2.7315347 , 0.73871464],
|
||||
[ 1.1837728 , -2.3819368 , 0.77078664, 0.5918046 ],
|
||||
[-4.019883 , 1.7254075 , 0.93185544, -0.20608105]],
|
||||
|
||||
[[-1.7672386 , -4.4935865 , 7.287687 , 1.945105 ],
|
||||
[-0.3699937 , 3.9466763 , -2.143978 , -0.12787771],
|
||||
[-0.02148095, 3.0340018 , 6.5158176 , -1.6688819 ],
|
||||
[ 3.5351803 , 2.8901145 , 3.4657938 , 0.06964709]]],
|
||||
dtype=float32),),
|
||||
expected_outputs=(array([[[-1.2178137e+00, -1.4850428e+00, 2.5807054e+00, -2.7725415e+00],
|
||||
[-4.2592635e+00, -1.5749537e+00, -2.7315347e+00, 7.3871464e-01],
|
||||
[ 2.3575491e-01, 2.9705398e+00, -4.0452647e+00, 5.9180462e-01],
|
||||
[-8.0058199e-01, -9.9736643e-01, -1.4680552e-01, 1.7293599e+00]],
|
||||
|
||||
[[-1.7672386e+00, -4.4935865e+00, 7.2876868e+00, 1.9451050e+00],
|
||||
[ 3.5545545e+00, -5.2433920e-01, -2.1439781e+00, -1.2787771e-01],
|
||||
[ 5.4734843e-03, -3.9149244e+00, 9.0688534e+00, -1.6688819e+00],
|
||||
[-9.0078670e-01, 3.4653693e-01, 1.6330767e-01, 1.9876275e+00]]],
|
||||
dtype=float32), array([[-1.2178137, -1.5749537, -4.0452647, 1.7293599],
|
||||
[-1.7672386, -0.5243392, 9.068853 , 1.9876275]], dtype=float32), array([[-4.2592635 , 2.9705398 , -0.14680552],
|
||||
[ 3.5545545 , -3.9149244 , 0.16330767]], dtype=float32), array([[1.1788895, 1.002637 , 0. ],
|
||||
[1.10409 , 1.7855742, 0. ]], dtype=float32)),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("x")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<2x4x4xf32> loc("x")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x3xf32> {jax.result_info = "[2]"}, tensor<2x3xf32> {jax.result_info = "[3]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @cusolver_sytrd_ffi(%arg0) {mhlo.backend_config = {lower = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x4x4xi1> loc(#loc5)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc5)
|
||||
%5 = stablehlo.select %3, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc6)
|
||||
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x4xi1> loc(#loc5)
|
||||
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc5)
|
||||
%10 = stablehlo.select %8, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc6)
|
||||
%11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%12 = stablehlo.compare EQ, %0#4, %11, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%13 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%14 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3xf32> loc(#loc5)
|
||||
%15 = stablehlo.select %13, %0#2, %14 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc6)
|
||||
%16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%17 = stablehlo.compare EQ, %0#4, %16, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%18 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3xf32> loc(#loc5)
|
||||
%20 = stablehlo.select %18, %0#3, %19 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc6)
|
||||
return %5, %10, %15, %20 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":831:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/tridiagonal"(#loc2))
|
||||
#loc4 = loc("jit(func)/jit(main)/eq"(#loc2))
|
||||
#loc5 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc2))
|
||||
#loc6 = loc("jit(func)/jit(main)/select_n"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.3\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7u/\x01-\x0f\x0f\x07\x17\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03I\x0f\x0b\x0b\x0b/Oo\x0f\x0b\x0b\x1b\x13\x0b\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03+\x17\x1b\x13\x07\x17\x13\x07\x07\x0f\x0f\x07\x07\x17#\x13\x13\x13\x13\x1b\x13\x17\x02\xe6\x04\x1d'\x07\x1d)\x07\x1f\x17%\xfe\x0c\x1b\x1d+\x07\x11\x03\x05\x03\x07\x0f\x11\x13\x0b\x15\x0b\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x1b\x05\x05\x1d\x03\x03\x1f[\x05\x1f\x1d#\x07\x05!\x05#\x05%\x05'\x05)\x1f'\x01\x1d+\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03=\r\x01#\x1f\x03\tCGKO\r\x03/E\x1d-\r\x03/I\x1d/\r\x03/M\x1d1\r\x03/Q\x1d3\x1d5\x1d7\x1f\x15\t\x00\x00\xc0\x7f\x1f\x17\t\x00\x00\x00\x00\r\x03]_\x1d9\x05\x03\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x039\x03\x03o\x15\x03\x01\x01\x01\x03\x0b9777s\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x0b)\x07\t\x11\x11\x0b)\x03\t\x19\t)\x05\t\x11\x0b)\x03\t\x11\x01\x13)\x01\x0b)\x01\x19\x1b\x1d)\x05\t\r\x11\x11\x03\x07\t\x07\r\x05\x05)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\x1b)\x07\t\x11\x11\x11)\x03\x05\x1b)\x05\t\x11\x11\x04\x96\x03\x05\x01Q\x05\r\x01\x07\x04n\x03\x03\x01\x05\x0bP\x05\x03\x07\x04B\x03\x039c\x03\x0f\x19\x00\tB\x05\x05\x03\x15\tB\x05\x07\x03\x17\rG!\x1d\t\x0b\x07\r\x05\x05\t\x03\x01\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f\x11\x03F\x03\x0f\x03)\x03\x13\x03F\x03\x0b\x03\x07\x03\x03\x07\x06\t\x03\x07\x07\x15\x07\x17\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f\x1b\x03F\x03\x0f\x03-\x03\x1d\x03F\x03\x0b\x03\r\x03\x03\x07\x06\t\x03\r\x07\x1f\t!\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f%\x03F\x03\x0f\x03\x1d\x03'\x03F\x03\x0b\x03\x05\x03\x03\x07\x06\t\x03\x05\x07)\x0b+\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f/\x03F\x03\x0f\x03\x1d\x031\x03F\x03\x0b\x03\x05\x03\x03\x07\x06\t\x03\x05\x073\r5\x0f\x04\x05\t\x19#-7\x06\x03\x01\x05\x01\x00z\x07?'\x03\r\x0f\x0b\t\t\t\t!;K/iA)\x05\x13%)9\x15\x1f\x11\x19\x15\x17)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00compare_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/select_n\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00lower\x00\x00cusolver_sytrd_ffi\x00\x08=\x11\x05/\x01\x0b;?ASU\x03W\x03Y\x11acegikmq\x03-\x0513\x035",
xla_call_module_version=9,
nr_devices=1,
) # End paste

data_2025_01_09["f64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_sytrd_ffi'],
serialized_date=datetime.date(2025, 1, 9),
inputs=(array([[[-3.1253059079430567e+00, 5.7973737786717416e+00,
|
||||
1.2965480927655553e+00, 6.4817268627909046e+00],
|
||||
[-1.1936619933620258e+00, 1.2510994308173569e+00,
|
||||
1.5304146964685636e-01, -3.9365359484492188e+00],
|
||||
[-4.7758906372514920e+00, -8.3938964119717618e-01,
|
||||
9.7688102177825908e-01, -1.0925284023457876e-01],
|
||||
[ 1.8645139381221081e+00, -1.2926082118180000e+00,
|
||||
2.1621899240570671e-03, -3.9399456873538158e-02]],
|
||||
|
||||
[[ 1.8036782176339347e-01, -2.7779317198395299e+00,
|
||||
1.3453106407991833e+00, 5.9722843958961107e+00],
|
||||
[ 2.8866256663925954e-01, -1.4318659824435502e+00,
|
||||
6.4080201583537226e+00, 1.1943604455356474e+00],
|
||||
[ 2.6497910527857116e-01, -1.4861093982154459e+00,
|
||||
5.4248701360547784e+00, -1.5973136783857038e+00],
|
||||
[ 6.3417141892397311e+00, -3.7002383133493355e+00,
|
||||
-1.8965912710753372e+00, -2.4871387323298411e+00]]]),),
|
||||
expected_outputs=(array([[[-3.1253059079430567 , 5.797373778671742 ,
|
||||
1.2965480927655553 , 6.481726862790905 ],
|
||||
[ 5.264064262415029 , 0.7243571624190226 ,
|
||||
0.15304146964685636, -3.936535948449219 ],
|
||||
[ 0.7395622620235721 , 0.18927242291294558,
|
||||
1.2785065803968352 , -0.10925284023457876],
|
||||
[-0.28872607234692266, -0.20307594403518822,
|
||||
-1.58216916428174 , 0.18571725290621857]],
|
||||
|
||||
[[ 0.18036782176339347, -2.77793171983953 ,
|
||||
1.3453106407991833 , 5.972284395896111 ],
|
||||
[-6.353808217251881 , -2.9702952281045247 ,
|
||||
6.408020158353723 , 1.1943604455356474 ],
|
||||
[ 0.03989164783700368, -4.028630349114961 ,
|
||||
-0.9432585642079134 , -1.5973136783857038 ],
|
||||
[ 0.9547221802795454 , -0.6834393717622312 ,
|
||||
-1.599653328439445 , 5.419419213593815 ]]]), array([[-3.1253059079430567 , 0.7243571624190226 , 1.2785065803968352 ,
|
||||
0.18571725290621857],
|
||||
[ 0.18036782176339347, -2.9702952281045247 , -0.9432585642079134 ,
|
||||
5.419419213593815 ]]), array([[ 5.264064262415029 , 0.18927242291294558, -1.58216916428174 ],
|
||||
[-6.353808217251881 , -4.028630349114961 , -1.599653328439445 ]]), array([[1.2267567289944903, 1.9207870511685834, 0. ],
|
||||
[1.0454314258109778, 1.3632434630444663, 0. ]])),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("x")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<2x4x4xf64> loc("x")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x3xf64> {jax.result_info = "[2]"}, tensor<2x3xf64> {jax.result_info = "[3]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @cusolver_sytrd_ffi(%arg0) {mhlo.backend_config = {lower = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x4x4xi1> loc(#loc5)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc5)
|
||||
%5 = stablehlo.select %3, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc6)
|
||||
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x4xi1> loc(#loc5)
|
||||
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc5)
|
||||
%10 = stablehlo.select %8, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc6)
|
||||
%11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%12 = stablehlo.compare EQ, %0#4, %11, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%13 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%14 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x3xf64> loc(#loc5)
|
||||
%15 = stablehlo.select %13, %0#2, %14 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc6)
|
||||
%16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%17 = stablehlo.compare EQ, %0#4, %16, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%18 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x3xf64> loc(#loc5)
|
||||
%20 = stablehlo.select %18, %0#3, %19 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc6)
|
||||
return %5, %10, %15, %20 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":831:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/tridiagonal"(#loc2))
|
||||
#loc4 = loc("jit(func)/jit(main)/eq"(#loc2))
|
||||
#loc5 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc2))
|
||||
#loc6 = loc("jit(func)/jit(main)/select_n"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.3\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7u/\x01-\x0f\x0f\x07\x17\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03I\x0f\x0b\x0b\x0b/Oo\x0f\x0b\x0b\x1b\x13\x0b\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x1f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03+\x17\x1b\x13\x07\x17\x13\x07\x07\x0f\x0f\x07\x07\x17#\x13\x13\x13\x13\x1b\x13\x17\x02\xf6\x04\x1d'\x07\x1d)\x07\x1f\x17%\xfe\x0c\x1b\x1d+\x07\x11\x03\x05\x03\x07\x0f\x11\x13\x0b\x15\x0b\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x1b\x05\x05\x1d\x03\x03\x1f[\x05\x1f\x1d#\x07\x05!\x05#\x05%\x05'\x05)\x1f'\x01\x1d+\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03=\r\x01#\x1f\x03\tCGKO\r\x03/E\x1d-\r\x03/I\x1d/\r\x03/M\x1d1\r\x03/Q\x1d3\x1d5\x1d7\x1f\x15\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17\t\x00\x00\x00\x00\r\x03]_\x1d9\x05\x03\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x039\x03\x03o\x15\x03\x01\x01\x01\x03\x0b9777s\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x0b)\x07\t\x11\x11\x0b)\x03\t\x19\x0b)\x05\t\x11\x0b)\x03\t\x11\x01\x13)\x01\x0b)\x01\x19\x1b\x1d)\x05\t\r\x11\x11\x03\x07\t\x07\r\x05\x05)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\x1b)\x07\t\x11\x11\x11)\x03\x05\x1b)\x05\t\x11\x11\x04\x96\x03\x05\x01Q\x05\r\x01\x07\x04n\x03\x03\x01\x05\x0bP\x05\x03\x07\x04B\x03\x039c\x03\x0f\x19\x00\tB\x05\x05\x03\x15\tB\x05\x07\x03\x17\rG!\x1d\t\x0b\x07\r\x05\x05\t\x03\x01\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f\x11\x03F\x03\x0f\x03)\x03\x13\x03F\x03\x0b\x03\x07\x03\x03\x07\x06\t\x03\x07\x07\x15\x07\x17\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f\x1b\x03F\x03\x0f\x03-\x03\x1d\x03F\x03\x0b\x03\r\x03\x03\x07\x06\t\x03\r\x07\x1f\t!\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f%\x03F\x03\x0f\x03\x1d\x03'\x03F\x03\x0b\x03\x05\x03\x03\x07\x06\t\x03\x05\x07)\x0b+\x03F\x01\x0b\x03\t\x03\x05\x05F\x01\r\x03\x0f\x05\x0f/\x03F\x03\x0f\x03\x1d\x031\x03F\x03\x0b\x03\x05\x03\x03\x07\x06\t\x03\x05\x073\r5\x0f\x04\x05\t\x19#-7\x06\x03\x01\x05\x01\x00z\x07?'\x03\r\x0f\x0b\t\t\t\t!;K/iA)\x05\x13%)9\x15\x1f\x11\x19\x15\x17)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00compare_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/select_n\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00lower\x00\x00cusolver_sytrd_ffi\x00\x08=\x11\x05/\x01\x0b;?ASU\x03W\x03Y\x11acegikmq\x03-\x0513\x035",
xla_call_module_version=9,
nr_devices=1,
) # End paste

data_2025_01_09["c64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_sytrd_ffi'],
serialized_date=datetime.date(2025, 1, 9),
inputs=(array([[[ 3.1812036 -0.48341316j , 0.11142776-6.424985j ,
|
||||
1.6697654 -0.7869761j , 2.3984313 +0.033103157j],
|
||||
[-4.391181 +3.711728j , 1.062603 -1.5042974j ,
|
||||
3.2630582 -3.0468545j , -1.1403834 +4.834375j ],
|
||||
[-3.3597422 +6.9571705j , 3.0927958 -3.3640988j ,
|
||||
-0.55644 +0.06318861j , -3.0856206 +2.0522592j ],
|
||||
[-0.12403099-0.7970628j , -0.56212497-2.0354905j ,
|
||||
-1.429683 +0.91537094j , 3.375416 +0.4935444j ]],
|
||||
|
||||
[[-1.8590846 -0.39196187j , 0.38138282+3.0784519j ,
|
||||
1.966164 -0.5697291j , 0.21431345+0.66547275j ],
|
||||
[-3.575784 -2.2677388j , -3.523877 -0.12984382j ,
|
||||
1.5619129 +4.8201146j , -1.321369 -1.2919399j ],
|
||||
[-1.3055397 +5.412672j , -2.538718 -1.8749187j ,
|
||||
1.4178271 +0.8454019j , 1.6860713 +3.5848062j ],
|
||||
[-3.2620752 -5.5018277j , 1.5509791 +2.9875515j ,
|
||||
-1.6966074 +1.4490732j , -2.3220365 +0.102485105j]]],
|
||||
dtype=complex64),),
|
||||
expected_outputs=(array([[[ 3.1812036e+00+0.0000000e+00j, 1.1142776e-01-6.4249849e+00j,
|
||||
1.6697654e+00-7.8697610e-01j, 2.3984313e+00+3.3103157e-02j],
|
||||
[ 9.6643763e+00+0.0000000e+00j, 4.1165028e+00+0.0000000e+00j,
|
||||
3.2630582e+00-3.0468545e+00j, -1.1403834e+00+4.8343749e+00j],
|
||||
[ 3.4564066e-01-4.0370134e-01j, -2.6924424e+00+0.0000000e+00j,
|
||||
-3.6695964e+00+0.0000000e+00j, -3.0856206e+00+2.0522592e+00j],
|
||||
[-5.7498864e-03+5.5189621e-02j, 1.9120899e-01-4.7203872e-02j,
|
||||
2.5072362e+00+0.0000000e+00j, 3.4346726e+00+2.1266517e-09j]],
|
||||
|
||||
[[-1.8590846e+00+0.0000000e+00j, 3.8138282e-01+3.0784519e+00j,
|
||||
1.9661640e+00-5.6972909e-01j, 2.1431345e-01+6.6547275e-01j],
|
||||
[ 9.4784794e+00+0.0000000e+00j, 3.5050194e+00+0.0000000e+00j,
|
||||
1.5619129e+00+4.8201146e+00j, -1.3213691e+00-1.2919399e+00j],
|
||||
[ 2.7161252e-02-4.1934705e-01j, -2.4255285e+00+0.0000000e+00j,
|
||||
-2.9377868e+00+0.0000000e+00j, 1.6860713e+00+3.5848062e+00j],
|
||||
[ 3.1363532e-01+3.6697471e-01j, 5.7014298e-01-8.8611171e-03j,
|
||||
-2.7132864e+00+0.0000000e+00j, -4.9953189e+00-4.8905555e-09j]]],
|
||||
dtype=complex64), array([[ 3.1812036, 4.116503 , -3.6695964, 3.4346726],
|
||||
[-1.8590846, 3.5050194, -2.9377868, -4.995319 ]], dtype=float32), array([[ 9.664376 , -2.6924424, 2.5072362],
|
||||
[ 9.478479 , -2.4255285, -2.7132864]], dtype=float32), array([[1.4543678-0.3840629j , 1.8027349-0.47009134j,
|
||||
1.9218173+0.38762447j],
|
||||
[1.3772529+0.23925133j, 1.2086039+0.6028182j ,
|
||||
1.9929025-0.11893065j]], dtype=complex64)),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("x")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<2x4x4xcomplex<f32>> loc("x")) -> (tensor<2x4x4xcomplex<f32>> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x3xf32> {jax.result_info = "[2]"}, tensor<2x3xcomplex<f32>> {jax.result_info = "[3]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
|
||||
%cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @cusolver_sytrd_ffi(%arg0) {mhlo.backend_config = {lower = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f32>>) -> (tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex<f32>>, tensor<2xi32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x4x4xi1> loc(#loc5)
|
||||
%4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc5)
|
||||
%5 = stablehlo.select %3, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc6)
|
||||
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x4xi1> loc(#loc5)
|
||||
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc5)
|
||||
%10 = stablehlo.select %8, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc6)
|
||||
%11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%12 = stablehlo.compare EQ, %0#4, %11, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%13 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%14 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3xf32> loc(#loc5)
|
||||
%15 = stablehlo.select %13, %0#2, %14 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc6)
|
||||
%16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%17 = stablehlo.compare EQ, %0#4, %16, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%18 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%19 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f32>>) -> tensor<2x3xcomplex<f32>> loc(#loc5)
|
||||
%20 = stablehlo.select %18, %0#3, %19 : tensor<2x3xi1>, tensor<2x3xcomplex<f32>> loc(#loc6)
|
||||
return %5, %10, %15, %20 : tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex<f32>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":831:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/tridiagonal"(#loc2))
|
||||
#loc4 = loc("jit(func)/jit(main)/eq"(#loc2))
|
||||
#loc5 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc2))
|
||||
#loc6 = loc("jit(func)/jit(main)/select_n"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.3\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbfw5\x01-\x0f\x0f\x07\x17\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03K\x0f\x0b\x0b\x0b/Oo\x0f\x0b\x0b\x1b\x13\x0b\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f/\x1f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x031\x1b\x13\x07\x17\x17\x17\x13\x07\x0b\x07\x0f\x0f\x0f\x07\x07\x17#\x13\x13\x13\x13\x1b\x13\x17\x02:\x05\x1d'\x07\x1d)\x07\x1f\x17%\xfe\x0c\x1b\x1d+\x07\x11\x03\x05\x03\x07\x0f\x11\x13\x0b\x15\x0b\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x1b\x05\x05\x1d\x03\x03\x1f]\x05\x1f\x1d#\x07\x05!\x05#\x05%\x05'\x05)\x1f-\x01\x1d+\t\x07\x07\x01\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03=\r\x01#%\x03\tCGKO\r\x03/E\x1d-\r\x03/I\x1d/\r\x03/M\x1d1\r\x03/Q\x1d3\x1d5\x1d7\x1f\x19\t\x00\x00\xc0\x7f\x1f\x1b\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1d\t\x00\x00\x00\x00\r\x03_a\x1d9\x05\x03\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x039\x03\x03q\x15\x03\x01\x01\x01\x03\x0b9777u\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x03\t\x1f\t)\x05\t\x11\t)\x05\t\r\t)\x05\t\r\x15)\x03\t\x13\x01\x03\t\x13)\x01\t)\x01\x15)\x01\x1f\x1b\x1d)\x05\t\r\x13\x11\x03\x05\t\x05\x0b\r\x0f)\x03\r\x17)\x03\t\x17)\x03\x05\x17)\x03\x01!)\x07\t\x11\x11\x13)\x03\x05!)\x05\t\x11\x13\x04\xae\x03\x05\x01Q\x05\r\x01\x07\x04\x86\x03\x03\x01\x05\x0bP\x05\x03\x07\x04Z\x03\x03;g\x03\x0b\x19\x00\tB\x05\x05\x03\x19\tB\x05\x07\x03\x1b\tB\x05\t\x03\x1d\rG!\x1d\x0b\x0b\x05\x0b\r\x0f\x07\x03\x01\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11\x13\x03F\x03\x11\x03/\x03\x15\x03F\x03\r\x03\x05\x03\x05\x07\x06\t\x03\x05\x07\x17\t\x19\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11\x1d\x03F\x03\x11\x033\x03\x1f\x03F\x03\r\x03\x0b\x03\x03\x07\x06\t\x03\x0b\x07!\x0b#\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11'\x03F\x03\x11\x03#\x03)\x03F\x03\r\x03\r\x03\x03\x07\x06\t\x03\r\x07+\r-\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x111\x03F\x03\x11\x03#\x033\x03F\x03\r\x03\x0f\x03\x05\x07\x06\t\x03\x0f\x075\x0f7\x0f\x04\x05\t\x1b%/9\x06\x03\x01\x05\x01\x00z\x07?'\x03\r\x0f\x0b\t\t\t\t!;K/iA)\x05\x13%)9\x15\x1f\x11\x19\x15\x17)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00compare_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/select_n\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00lower\x00\x00cusolver_sytrd_ffi\x00\x08A\x13\x05/\x01\x0b;?ASU\x03W\x03Y\x03[\x11cegikmos\x03-\x0513\x035",
xla_call_module_version=9,
nr_devices=1,
) # End paste

data_2025_01_09["c128"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_sytrd_ffi'],
serialized_date=datetime.date(2025, 1, 9),
inputs=(array([[[-8.811619206101685 -6.346103055817515j ,
|
||||
-2.0673625472346933 -0.27575832037940934j,
|
||||
0.17956803185065526+0.4305341577966682j ,
|
||||
1.1602180635956165 +0.9004656831972404j ],
|
||||
[ 0.04561867504808473-4.771560375289864j ,
|
||||
4.778928872206952 +1.7033875050911593j ,
|
||||
-3.8122399421775315 -5.795050026499773j ,
|
||||
0.5791695386369239 +0.9628332000212088j ],
|
||||
[-1.9993689229562661 -4.064915227679414j ,
|
||||
-2.3979548160862456 +0.8988248576397159j ,
|
||||
-2.7272138752203796 +0.21666059983248281j,
|
||||
2.067134647794104 -0.14136984404543293j],
|
||||
[ 5.485533829522301 +1.1676501594179898j ,
|
||||
-3.009376295256183 +1.2776833329344761j ,
|
||||
-1.0223678823537847 -1.2013281507542595j ,
|
||||
0.08734463104455259-2.420057264345883j ]],
|
||||
|
||||
[[ 2.212505188041518 -3.2177849380091894j ,
|
||||
-1.0593623525670381 +5.37426459023754j ,
|
||||
-3.1623822399052592 +0.8649520447907271j ,
|
||||
-1.7388123297543543 +3.2870543762142654j ],
|
||||
[ 2.793400596418774 -0.739715064363405j ,
|
||||
8.739016966775436 +5.022683652842428j ,
|
||||
0.7559458452282815 +4.110352398353875j ,
|
||||
1.7261562225678564 -3.8393940384377196j ],
|
||||
[ 0.37042546632316864-0.1996809256297461j ,
|
||||
-1.3181774150972687 -0.32212366971341144j,
|
||||
1.3177595427368969 -1.9978296888748126j ,
|
||||
-0.6558210664615476 +0.08467474893379115j],
|
||||
[-1.7318040703079298 -4.599436443227647j ,
|
||||
-0.20128653725409817-1.115567638745429j ,
|
||||
-1.579624976405968 +7.379126520550004j ,
|
||||
2.0570113545830235 +2.5303646286484853j ]]]),),
|
||||
expected_outputs=(array([[[-8.811619206101685 +0.0000000000000000e+00j,
|
||||
-2.0673625472346933 -2.7575832037940934e-01j,
|
||||
0.17956803185065526 +4.3053415779666820e-01j,
|
||||
1.1602180635956165 +9.0046568319724041e-01j],
|
||||
[-8.645540449646582 +0.0000000000000000e+00j,
|
||||
0.3845716078748591 +0.0000000000000000e+00j,
|
||||
-3.8122399421775315 -5.7950500264997729e+00j,
|
||||
0.5791695386369239 +9.6283320002120876e-01j],
|
||||
[ 0.02053989913865161 -4.5643024157336864e-01j,
|
||||
5.036492728413716 +0.0000000000000000e+00j,
|
||||
1.84687034311969 +1.1102230246251565e-16j,
|
||||
2.067134647794104 -1.4136984404543293e-01j],
|
||||
[ 0.4283052472392046 +3.6949438614565994e-01j,
|
||||
0.21688243005340535 +5.1643916348653485e-01j,
|
||||
-2.7797703440504513 +0.0000000000000000e+00j,
|
||||
-0.09238232296342479 -6.5179386557032127e-17j]],
|
||||
|
||||
[[ 2.212505188041518 +0.0000000000000000e+00j,
|
||||
-1.0593623525670381 +5.3742645902375399e+00j,
|
||||
-3.1623822399052592 +8.6495204479072707e-01j,
|
||||
-1.7388123297543543 +3.2870543762142654e+00j],
|
||||
[-5.71675727138259 +0.0000000000000000e+00j,
|
||||
3.7004622018069355 +0.0000000000000000e+00j,
|
||||
0.7559458452282815 +4.1103523983538750e+00j,
|
||||
1.7261562225678564 -3.8393940384377196e+00j],
|
||||
[ 0.04522526729096931 -1.9532788545989693e-02j,
|
||||
-8.147585314007962 +0.0000000000000000e+00j,
|
||||
1.0622507332612754 +0.0000000000000000e+00j,
|
||||
-0.6558210664615476 +8.4674748933791150e-02j],
|
||||
[-0.155346841147821 -5.5396726066186741e-01j,
|
||||
-0.036849829076318265-2.2002105614858228e-01j,
|
||||
-0.6549744056215347 +0.0000000000000000e+00j,
|
||||
7.351074929027147 -9.3780430643058420e-18j]]]), array([[-8.811619206101685 , 0.3845716078748591 , 1.84687034311969 ,
|
||||
-0.09238232296342479],
|
||||
[ 2.212505188041518 , 3.7004622018069355 , 1.0622507332612754 ,
|
||||
7.351074929027147 ]]), array([[-8.645540449646582 , 5.036492728413716 , -2.7797703440504513],
|
||||
[-5.71675727138259 , -8.147585314007962 , -0.6549744056215347]]), array([[1.0052765556200653-0.5519100168555592j ,
|
||||
1.2933847199153174+0.5442027059704965j ,
|
||||
1.8134445415378648+0.5816424828382576j ],
|
||||
[1.4886337592820684-0.12939417037458123j,
|
||||
1.7898082606090093-0.454423909708186j ,
|
||||
1.8735743590190088-0.48669070184720786j]])),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("x")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<2x4x4xcomplex<f64>> loc("x")) -> (tensor<2x4x4xcomplex<f64>> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x3xf64> {jax.result_info = "[2]"}, tensor<2x3xcomplex<f64>> {jax.result_info = "[3]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
|
||||
%cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @cusolver_sytrd_ffi(%arg0) {mhlo.backend_config = {lower = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f64>>) -> (tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex<f64>>, tensor<2xi32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x4x4xi1> loc(#loc5)
|
||||
%4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc5)
|
||||
%5 = stablehlo.select %3, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc6)
|
||||
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x4xi1> loc(#loc5)
|
||||
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc5)
|
||||
%10 = stablehlo.select %8, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc6)
|
||||
%11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%12 = stablehlo.compare EQ, %0#4, %11, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%13 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%14 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x3xf64> loc(#loc5)
|
||||
%15 = stablehlo.select %13, %0#2, %14 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc6)
|
||||
%16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc4)
|
||||
%17 = stablehlo.compare EQ, %0#4, %16, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc4)
|
||||
%18 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<2xi1>) -> tensor<2x3xi1> loc(#loc5)
|
||||
%19 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<complex<f64>>) -> tensor<2x3xcomplex<f64>> loc(#loc5)
|
||||
%20 = stablehlo.select %18, %0#3, %19 : tensor<2x3xi1>, tensor<2x3xcomplex<f64>> loc(#loc6)
|
||||
return %5, %10, %15, %20 : tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex<f64>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":831:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/tridiagonal"(#loc2))
|
||||
#loc4 = loc("jit(func)/jit(main)/eq"(#loc2))
|
||||
#loc5 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc2))
|
||||
#loc6 = loc("jit(func)/jit(main)/select_n"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.3\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbfw5\x01-\x0f\x0f\x07\x17\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03K\x0f\x0b\x0b\x0b/Oo\x0f\x0b\x0b\x1b\x13\x0b\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/O\x1f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x031\x1b\x13\x07\x17\x17\x17\x13\x07\x0b\x07\x0f\x0f\x0f\x07\x07\x17#\x13\x13\x13\x13\x1b\x13\x17\x02j\x05\x1d'\x07\x1d)\x07\x1f\x17%\xfe\x0c\x1b\x1d+\x07\x11\x03\x05\x03\x07\x0f\x11\x13\x0b\x15\x0b\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x1b\x05\x05\x1d\x03\x03\x1f]\x05\x1f\x1d#\x07\x05!\x05#\x05%\x05'\x05)\x1f-\x01\x1d+\t\x07\x07\x01\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03=\r\x01#%\x03\tCGKO\r\x03/E\x1d-\r\x03/I\x1d/\r\x03/M\x1d1\r\x03/Q\x1d3\x1d5\x1d7\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d\t\x00\x00\x00\x00\r\x03_a\x1d9\x05\x03\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x039\x03\x03q\x15\x03\x01\x01\x01\x03\x0b9777u\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x03\t\x1f\x0b)\x05\t\x11\t)\x05\t\r\t)\x05\t\r\x15)\x03\t\x13\x01\x03\t\x13)\x01\t)\x01\x15)\x01\x1f\x1b\x1d)\x05\t\r\x13\x11\x03\x05\t\x05\x0b\r\x0f)\x03\r\x17)\x03\t\x17)\x03\x05\x17)\x03\x01!)\x07\t\x11\x11\x13)\x03\x05!)\x05\t\x11\x13\x04\xae\x03\x05\x01Q\x05\r\x01\x07\x04\x86\x03\x03\x01\x05\x0bP\x05\x03\x07\x04Z\x03\x03;g\x03\x0b\x19\x00\tB\x05\x05\x03\x19\tB\x05\x07\x03\x1b\tB\x05\t\x03\x1d\rG!\x1d\x0b\x0b\x05\x0b\r\x0f\x07\x03\x01\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11\x13\x03F\x03\x11\x03/\x03\x15\x03F\x03\r\x03\x05\x03\x05\x07\x06\t\x03\x05\x07\x17\t\x19\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11\x1d\x03F\x03\x11\x033\x03\x1f\x03F\x03\r\x03\x0b\x03\x03\x07\x06\t\x03\x0b\x07!\x0b#\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x11'\x03F\x03\x11\x03#\x03)\x03F\x03\r\x03\r\x03\x03\x07\x06\t\x03\r\x07+\r-\x03F\x01\r\x03\x07\x03\x07\x05F\x01\x0f\x03\x11\x05\x111\x03F\x03\x11\x03#\x033\x03F\x03\r\x03\x0f\x03\x05\x07\x06\t\x03\x0f\x075\x0f7\x0f\x04\x05\t\x1b%/9\x06\x03\x01\x05\x01\x00z\x07?'\x03\r\x0f\x0b\t\t\t\t!;K/iA)\x05\x13%)9\x15\x1f\x11\x19\x15\x17)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00compare_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/select_n\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00lower\x00\x00cusolver_sytrd_ffi\x00\x08A\x13\x05/\x01\x0b;?ASU\x03W\x03Y\x03[\x11cegikmos\x03-\x0513\x035",
xla_call_module_version=9,
nr_devices=1,
) # End paste
@@ -2972,29 +2972,35 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):

batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower, platform):
def _tridiagonal_cpu_hlo(ctx, a, *, lower):
  a_aval, = ctx.avals_in
  cpu_args = []
  if platform == "cpu":
    # TODO(b/344892332): Remove the conditional after the compatibility period.
    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
    cpu_args.extend(ctx_args)
  a, d, e, taus, info = sytrd_impl(*cpu_args, a_aval.dtype, a, lower=lower)
  return a, d, e, taus, info
  return lapack.sytrd_hlo(ctx, a_aval.dtype, a, lower=lower)

def _tridiagonal_gpu_hlo(ctx, a, *, lower, target_name_prefix):
  operand_aval, = ctx.avals_in
  dims = operand_aval.shape
  batch_dims = dims[:-2]
  nb = len(batch_dims)
  layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  result_layouts = [layout, tuple(range(nb, -1, -1)), tuple(range(nb, -1, -1)),
                    tuple(range(nb, -1, -1)), tuple(range(nb - 1, -1, -1))]
  rule = ffi.ffi_lowering(f"{target_name_prefix}solver_sytrd_ffi",
                          operand_layouts=[layout],
                          result_layouts=result_layouts,
                          operand_output_aliases={0: 0})
  return rule(ctx, a, lower=lower)


mlir.register_lowering(
    tridiagonal_p,
    partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo, platform="cpu"),
    platform="cpu",
)
    tridiagonal_p, _tridiagonal_cpu_hlo, platform="cpu")
mlir.register_lowering(
    tridiagonal_p,
    partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd, platform="cuda"),
    partial(_tridiagonal_gpu_hlo, target_name_prefix="cu"),
    platform="cuda",
)
mlir.register_lowering(
    tridiagonal_p,
    partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd, platform="rocm"),
    partial(_tridiagonal_gpu_hlo, target_name_prefix="hip"),
    platform="rocm",
)
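
For reference, a small illustration (assumed batch shape, not part of the diff) of the layouts that the new _tridiagonal_gpu_hlo rule computes; they match the operand_layouts/result_layouts recorded in the serialized test data above.

    # Reproduce the layout computation for an operand of shape (2, 4, 4).
    dims = (2, 4, 4)
    batch_dims = dims[:-2]                                # (2,)
    nb = len(batch_dims)                                  # 1
    layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))  # (1, 2, 0): column-major matrices
    vec = tuple(range(nb, -1, -1))                        # (1, 0) for d, e, and taus
    info = tuple(range(nb - 1, -1, -1))                   # (0,) for the info output
    result_layouts = [layout, vec, vec, vec, info]
    print(layout, result_layouts)
    # (1, 2, 0) [(1, 2, 0), (1, 0), (1, 0), (1, 0), (0,)]
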
@@ -16,13 +16,12 @@ from functools import partial
import importlib

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np

from jaxlib import xla_client

from .hlo_helpers import custom_call, dense_int_array
from .hlo_helpers import custom_call

try:
  from .cuda import _blas as _cublas  # pytype: disable=import-error
@@ -131,11 +130,6 @@ def has_magma():
    return _hiphybrid.has_magma()
  return False

def _real_type(dtype):
  """Returns the real equivalent of 'dtype'."""
  return np.finfo(dtype).dtype


def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
                  indices, indptr, b, tol, reorder):
  """Sparse solver via QR decomposition. CUDA only."""
@@ -159,77 +153,3 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
  return out

cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)


def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
  """sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  lwork, opaque = gpu_solver.build_sytrd_descriptor(dtype, lower, b, n)
  if np.issubdtype(dtype, np.floating):
    diag_type = a_type.element_type
  elif dtype == np.complex64:
    diag_type = ir.F32Type.get()
  elif dtype == np.complex128:
    diag_type = ir.F64Type.get()
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  a, d, e, taus, info, _ = custom_call(
      f"{platform}solver_sytrd",
      result_types=[
          a.type,
          ir.RankedTensorType.get(batch_dims + (n,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      operands=[a],
      backend_config=opaque,
      operand_layouts=[layout],
      result_layouts=[
          layout,
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
          [0],
      ],
      operand_output_aliases={0: 0},
  ).results
  # Workaround for NVIDIA partners bug #3865118: sytrd returns an incorrect "1"
  # in the first element of the superdiagonal in the `a` matrix in the
  # lower=False case. The correct result is returned in the `e` vector so we can
  # simply copy it back to where it needs to be:
  intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
  intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
  if not lower and platform == "cu" and m > 1:
    start = (0,) * len(batch_dims) + (0,)
    end = batch_dims + (1,)
    s = hlo.slice(
        e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
    s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
    s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
    # The diagonals are always real; convert to complex if needed.
    s = hlo.convert(
        ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
    offsets = tuple(hlo.constant(intattr(i))
                    for i in ((0,) * len(batch_dims) + (0, 1)))
    a = hlo.dynamic_update_slice(a, s, offsets)

  return a, d, e, taus, info

cuda_sytrd = partial(_sytrd_hlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_hlo, "hip", _hipsolver)
@@ -50,6 +50,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cuda_threef
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf
from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd
from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
@@ -144,6 +145,7 @@ class CompatTest(bctu.CompatTestBase):
        cuda_eigh_cusolver_syev.data_2024_09_30,
        cuda_svd_cusolver_gesvd.data_2024_10_08,
        cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09,
        cuda_tridiagonal_cusolver_sytrd.data_2025_01_09,
        rocm_qr_hipsolver_geqrf.data_2024_08_05,
        rocm_eigh_hipsolver_syev.data_2024_08_05,
        cpu_schur_lapack_gees.data_2023_07_16,
@@ -783,7 +785,6 @@ class CompatTest(bctu.CompatTestBase):
    )
    self.run_one_test(func, data, rtol=rtol, atol=atol)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
@@ -834,6 +835,25 @@ class CompatTest(bctu.CompatTestBase):
    )
    self.run_one_test(lax.linalg.tridiagonal_solve, data, rtol=rtol, atol=atol)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_gpu_tridiagonal_solver_sytrd(self, dtype_name):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    def func(x):
      return lax.linalg.tridiagonal(x, lower=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    data = self.load_testdata(
        cuda_tridiagonal_cusolver_sytrd.data_2025_01_09[dtype_name]
    )
    self.run_one_test(func, data, rtol=rtol, atol=atol)

  def test_approx_top_k(self):
    def func():
      x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0])