Activate the FFI implementation of SVD on GPU.

Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.linalg.svd`. Previously the choice of algorithm was made using heuristics in the lowering rule, but since those heuristics are not very carefully optimized, it also makes sense to let users specify the algorithm explicitly.
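
A minimal sketch of the new parameter (the `SvdAlgorithm` enum and its `QR`/`JACOBI` members follow this change; treat the exact spellings as assumptions if you're reading this against a different jax/jaxlib version):

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.ones((16, 8, 8), dtype=jnp.float32)

# Default: the lowering rule picks a kernel heuristically from the shapes.
u, s, vh = lax_linalg.svd(a, full_matrices=False)

# Explicit: bypass the heuristics and request the QR-based kernel.
u, s, vh = lax_linalg.svd(
    a, full_matrices=False, algorithm=lax_linalg.SvdAlgorithm.QR)
```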

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit cuSOLVER API is used for the QR-based algorithm. (It looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API. A quick way to confirm that the new FFI target is used is sketched after this list.

2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default we use heuristics based on the matrix sizes to select the algorithm, and the three algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this selection logic into the kernel. If the symbolic constraints are not sufficient to resolve the heuristics concretely, we always use the QR algorithm. But the algorithm selection is also exposed in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed (see the export sketch after this list).
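
As a quick sanity check for the first point, one can inspect the lowered module for the new FFI custom-call target (a hedged sketch: it assumes a CUDA-backed installation, and the target name matches the ones registered as stable in this change):

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.ones((4, 4), dtype=jnp.float32)

# Force the QR-based kernel and look for its FFI target in the lowering.
hlo = jax.jit(
    lambda x: lax_linalg.svd(x, algorithm=lax_linalg.SvdAlgorithm.QR)
).lower(a).as_text()
assert "cusolver_gesvd_ffi" in hlo  # only when lowering for CUDA
```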
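
And for the second point, a hedged sketch of pinning the algorithm under shape polymorphism (the `jax.export` names here reflect the current export API and are assumptions, not part of this change):

```python
import jax
import jax.numpy as jnp
from jax import export
from jax.lax import linalg as lax_linalg

def f(x):
  # Pinning the algorithm keeps the lowering independent of the size-based
  # heuristics, which can't always be resolved for symbolic dimensions.
  return lax_linalg.svd(x, full_matrices=False,
                        algorithm=lax_linalg.SvdAlgorithm.QR)

# Both the batch and the matrix dimensions are symbolic here.
shape = export.symbolic_shape("b, n, n")
exported = export.export(jax.jit(f), platforms=["cuda"])(
    jax.ShapeDtypeStruct(shape, jnp.float32))
```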

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
Commit 8361eb58e1 (parent 3e634d9530)
Dan Foreman-Mackey, 2024-10-17 17:56:33 -07:00, committed by jax authors
9 changed files with 1072 additions and 344 deletions


@@ -962,6 +962,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
# eigh on GPU
"cusolver_syevd_ffi", "hipsolver_syevd_ffi",
# svd on GPU
"cusolver_gesvd_ffi", "cusolver_gesvdj_ffi",
"hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
# lu on TPU
"LuDecomposition",
# ApproxTopK on TPU


@@ -0,0 +1,818 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
# ruff: noqa
import datetime
from numpy import array, float32, complex64
data_2024_10_08 = {"jacobi": {}, "qr": {}}
data_2024_10_08["jacobi"]["f32"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvdj_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 4.0477114 , -1.6619755 , 3.410165 , -0.3096872 ],
[ 1.7056831 , 0.17458144, 0.12622283, 1.8056804 ],
[-0.21965337, 1.7691472 , -0.8960539 , 1.1495945 ],
[ 1.4272052 , 3.2477028 , -1.5676605 , -1.7347798 ]],
[[-3.914366 , -2.8975718 , 1.5051024 , 1.7379241 ],
[-5.045383 , 1.4189544 , -0.61290324, 1.7456682 ],
[-4.823716 , 0.32512662, -1.9951408 , 3.632175 ],
[ 0.8716193 , -0.24001008, 1.5933404 , 1.0177776 ]]],
dtype=float32),),
expected_outputs=(array([[[ 0.9018128 , 0.31748173 , 0.090800084, -0.27873662 ],
[ 0.1856214 , 0.18972664 , -0.78390217 , 0.56128997 ],
[-0.25094122 , 0.20106053 , -0.55834264 , -0.7647598 ],
[-0.29884213 , 0.90707415 , 0.25594372 , 0.1496738 ]],
[[-0.45493636 , -0.8541334 , -0.20037504 , -0.15276858 ],
[-0.57877773 , 0.30431184 , -0.4479597 , 0.6097069 ],
[-0.6747176 , 0.29090276 , 0.57102525 , -0.3661442 ],
[ 0.052960183, -0.30532822 , 0.65811265 , 0.68619025 ]]],
dtype=float32), array([[5.974016 , 4.183989 , 2.6312675, 0.5687128],
[9.106636 , 3.995191 , 1.8342099, 1.6058134]], dtype=float32), array([[[ 0.60185665 , -0.48223644 , 0.6347658 , 0.047846846],
[ 0.6833445 , 0.6709123 , -0.1184354 , -0.26246953 ],
[-0.18304126 , -0.16886286 , 0.11772618 , -0.96131265 ],
[ 0.37054697 , -0.53741086 , -0.75444424 , -0.0685464 ]],
[[ 0.8786726 , 0.0290856 , 0.12085139 , -0.46095958 ],
[ 0.034706447, 0.76957005 , -0.635503 , -0.051896986],
[ 0.4708456 , -0.014900906, -0.06417345 , 0.8797526 ],
[-0.07095488 , 0.6377255 , 0.75987667 , 0.10420583 ]]],
dtype=float32)),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32> loc(#loc3)
%2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%3 = stablehlo.compare EQ, %0#4, %2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%7 = stablehlo.select %6, %0#1, %5 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#2, %9 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %1, %13 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
return %11, %7, %15 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#'+\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02\x12\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x1d\x15\x03\x05\x1f\x03\x03\x19M\x05!\x05#\x17\x1f\xba\n\x1b\x05%\x1d'\x1d)\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x1d+\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##\x1f\x03\x079=A\r\x05);!#\x1d-\r\x05)?!#\x1d/\r\x05)C!#\x1d1\x1d3\x1d5\x1f\x11\t\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\x00\x00\r\x05O-Q-\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x03%\x03\x03a\x15\x03\x01\x01\x01\x03\x0b%e%%g\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\t\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x03\r\x0b)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b\x04\xe2\x02\x05\x01Q\x03\x07\x01\x07\x04\xba\x02\x03\x01\x05\tP\x03\x03\x07\x04\x8e\x02\x03/O\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\rF\x01\x0b\x03\x05\x03\r\x03F\x01\r\x03\x17\x03\x05\x0fF\x01\x0f\x03)\x05\x0f\x13\x03F\x01\x11\x03+\x03\x15\x03F\x01\r\x03\t\x03\x03\x03F\x01\x13\x03/\x03\x17\x05\x06\x01\x03\t\x07\x1b\t\x19\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\x1f\x05\x06\x01\x03\x05\x07#\x0b!\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03'\x05\x06\x01\x03\x05\x07+\x11)\x11\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xe6\x06?)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1b\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08I\x17\x05#\x01\x0b157EG\x03I\x03K\x11SUWY[]_c\x03i\x03'\x05km\x03+\x03o\x03/",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["jacobi"]["f64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvdj_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[-0.23389781892940775, -0.20108449262485306,
-0.5666115573270456 , -2.4789757754536694 ],
[-1.8555613779538866 , 2.9994543506103533 ,
1.087201973266891 , -1.012848084871914 ],
[-3.195786395215201 , -2.483536010628656 ,
-0.9470206294018368 , -0.4080455549606986 ],
[ 0.41241574420074084, -8.059831117309347 ,
-0.7882929058035465 , 2.8856408696497664 ]],
[[ 1.7959513084827519 , -6.006170699401665 ,
0.9365840264144545 , -2.8339481807478486 ],
[ 3.37344261819673 , -2.809695890050033 ,
0.9096330403907014 , -0.22091594063236752],
[-1.8994569852606458 , 0.8593563734072376 ,
3.3970755287480623 , 1.162324876281341 ],
[ 1.2730096750276905 , 3.1397664846677484 ,
-4.625276688205361 , -3.0618323131122303 ]]]),),
expected_outputs=(array([[[ 0.05538281351237727 , -0.34543122319556363 ,
-0.875138554715149 , 0.33427911101376073 ],
[ 0.35132368873545367 , -0.3583481459523727 ,
0.4466199100964102 , 0.7407353967047108 ],
[-0.2350018651348275 , -0.8649389618871873 ,
0.17010144637596905 , -0.40953658387679276 ],
[-0.904587493327162 , 0.06437754689044517 ,
0.07568793759482184 , 0.41454593771389403 ]],
[[-0.7655776828405998 , -0.3259951618597633 ,
-0.5487654998241592 , 0.08046360781859588 ],
[-0.4502679475178967 , -0.18435823314270097 ,
0.7918156001863894 , 0.3691867719894481 ],
[ 0.011780696812478626, 0.5608039473610089 ,
-0.23310619898043075 , 0.7943687102371408 ],
[ 0.4593591211209926 , -0.7384024166677944 ,
-0.13246124527990302 , 0.47561022634191513 ]]]), array([[9.502458469794536 , 4.039322683970868 , 2.3634256270400944,
0.7309141361765127],
[8.368723268857098 , 7.018568450518651 , 2.518568829019885 ,
1.4845675373399827]]), array([[[-0.030192917862639196, 0.9383993407299717 ,
0.13535558060292174 , -0.31650265690534246 ],
[ 0.8755039760059299 , 0.15444343242930395 ,
0.14222562381109818 , 0.4352147586063772 ],
[-0.4808404523560293 , 0.20440997041138523 ,
0.32185308272374646 , 0.7895692601131882 ],
[ 0.03706258337996993 , -0.23188459951757628 ,
0.9262080391966009 , -0.2949484116712019 ]],
[[-0.2785970537508859 , 0.8741728222022489 ,
-0.3837203651304145 , 0.10471026668126972 ],
[-0.45773005941421924 , 0.09111437041907021 ,
0.690652044371576 , 0.5524320028900885 ],
[ 0.7781161676166882 , 0.18065797801542668 ,
0.010754987459331431, 0.6014833787542634 ],
[ 0.32772260227744276 , 0.4414552564025818 ,
0.612897026615641 , -0.5675142177221092 ]]])),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xf64>) -> tensor<2x4x4xf64> loc(#loc3)
%2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%3 = stablehlo.compare EQ, %0#4, %2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%7 = stablehlo.select %6, %0#1, %5 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#2, %9 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %1, %13 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
return %11, %7, %15 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.7.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02"\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x1d\x15\x03\x05\x1f\x03\x03\x19M\x05!\x05#\x17\x1f\xba\n\x1b\x05%\x1d\'\x1d)\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\'\x01\x1d+\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##\x1f\x03\x079=A\r\x05);!#\x1d-\r\x05)?!#\x1d/\r\x05)C!#\x1d1\x1d3\x1d5\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\t\x00\x00\x00\x00\r\x05O-Q-\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x03%\x03\x03a\x15\x03\x01\x01\x01\x03\x0b%e%%g\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\x0b\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x03\r\x0b)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b\x04\xe2\x02\x05\x01Q\x03\x07\x01\x07\x04\xba\x02\x03\x01\x05\tP\x03\x03\x07\x04\x8e\x02\x03/O\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\rF\x01\x0b\x03\x05\x03\r\x03F\x01\r\x03\x17\x03\x05\x0fF\x01\x0f\x03)\x05\x0f\x13\x03F\x01\x11\x03+\x03\x15\x03F\x01\r\x03\t\x03\x03\x03F\x01\x13\x03/\x03\x17\x05\x06\x01\x03\t\x07\x1b\t\x19\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\x1f\x05\x06\x01\x03\x05\x07#\x0b!\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\'\x05\x06\x01\x03\x05\x07+\x11)\x11\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xe6\x06?)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1b\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08I\x17\x05#\x01\x0b157EG\x03I\x03K\x11SUWY[]_c\x03i\x03\'\x05km\x03+\x03o\x03/',
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["jacobi"]["c64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvdj_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 2.4649541 -5.8653884j , 5.5100183 +2.0214202j ,
-2.4541297 +1.862114j , 5.4709225 +4.409564j ],
[ 1.1091617 +2.325679j , -5.0506334 +5.5802264j ,
-1.7254959 -1.5569435j , 3.002013 -2.7583091j ],
[ 0.8154569 +5.66862j , -0.7711735 +1.8621845j ,
1.2456422 -1.1770611j , 0.03156909-0.22670403j],
[ 3.9012303 -2.0669405j , -2.7752936 -0.71004313j,
-1.0354352 -5.5713825j , 1.554125 +0.9581067j ]],
[[-2.077837 +6.2613506j , -2.0213401 -3.2755377j ,
-2.1061401 -0.4942127j , -5.098616 +3.4436114j ],
[ 3.6104472 +0.75928044j, 1.3155019 -3.6494553j ,
0.58335614-0.8751654j , -1.1484178 -4.0733714j ],
[-0.4858576 +4.38415j , -2.3318157 +2.744366j ,
-0.7987209 -0.23303579j, -2.6747904 +1.5206293j ],
[ 4.013358 +0.978174j , -2.707136 -0.29939318j,
-2.5241148 -1.44767j , -1.8191104 +0.26034543j]]],
dtype=complex64),),
expected_outputs=(array([[[-7.7744454e-01+0.09018909j , -2.1625498e-01-0.16622283j ,
-1.6743444e-01+0.5261705j , 8.9576304e-02+0.01171514j ],
[ 1.2709992e-01-0.38680157j , -3.4196457e-01+0.6261395j ,
1.3424198e-01+0.37196818j , -1.5489596e-01+0.38061285j ],
[ 2.6221693e-01-0.25399417j , 1.7437853e-01+0.053393736j ,
7.3115915e-02+0.51672214j , 2.4561302e-01-0.7076709j ],
[-5.0269447e-02-0.29305094j , -5.6865913e-01-0.24491112j ,
4.3156582e-01-0.28307965j , 5.0845784e-01-0.05768533j ]],
[[-1.5179910e-02+0.8092423j , -1.0622249e-01+0.09009284j ,
-5.3503644e-01-0.053553585j , 1.8145569e-01+0.058631808j ],
[-1.7091817e-01+0.0025223175j, 7.6241887e-01+0.2851526j ,
-3.0385127e-02-0.115484476j , 3.0117261e-01-0.45080122j ],
[ 1.7759532e-01+0.4123873j , -2.5654158e-01+0.052343685j ,
6.6419083e-01-0.32588708j , -1.6986026e-04-0.4271902j ],
[ 2.4243295e-01+0.23515853j , 4.7519770e-01-0.15375356j ,
3.5048491e-01-0.16253117j , 8.7767310e-02+0.6924699j ]]],
dtype=complex64), array([[13.7465725, 9.692749 , 6.012994 , 1.7378366],
[12.048737 , 7.9871097, 3.9069395, 1.7972969]], dtype=float32), array([[[-0.29246047 +0.582183j , -0.5259067 -0.27628872j ,
0.3469344 -0.15329583j , -0.19642796 -0.20042528j ],
[ 0.02593917 +0.33676162j , 0.55821204 +0.18806678j ,
0.20056807 +0.35542676j , -0.5978458 -0.12236632j ],
[ 0.46109042 -0.034899022j, 0.24078317 -0.19413127j ,
0.19841644 -0.3350929j , 0.17725058 -0.71234804j ],
[-0.48504043 -0.1111859j , 0.31423682 +0.32512894j ,
0.23620881 -0.69435585j , -0.04027982 +0.091500096j]],
[[ 0.6148358 +0.14274344j , -0.23763065 +0.3584556j ,
-0.13778959 +0.19841002j , 0.234248 +0.55083406j ],
[ 0.734292 -0.11843189j , -0.07720101 -0.4717579j ,
-0.051302787-0.19603764j , -0.16576114 -0.38695592j ],
[ 0.019250087+0.17437238j , -0.43637297 +0.6207004j ,
0.033975117-0.27825257j , 0.024782527-0.5606609j ],
[-0.060103483+0.11833174j , -0.074751794-0.072441235j,
-0.5370554 +0.7304642j , 0.07712744 -0.37893948j ]]],
dtype=complex64)),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f32>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f32>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f32>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f32>>) -> (tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>>, tensor<2x4x4xcomplex<f32>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%2 = stablehlo.real %1 : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xf32> loc(#loc3)
%3 = stablehlo.imag %1 : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xf32> loc(#loc3)
%4 = stablehlo.negate %3 : tensor<2x4x4xf32> loc(#loc3)
%5 = stablehlo.complex %2, %4 : tensor<2x4x4xcomplex<f32>> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %0#2, %13 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
%16 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%18 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%19 = stablehlo.select %18, %5, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
return %15, %11, %19 : tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01+\x05\x01\x05\x1b\x01\x03\x0b\x03\x19\x0f\x13\x17\x1b\x1f#'+/37;\x03\xbfs9\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x035\x1b\x07\x07\x17\x07\x07\x1b\x0b\x0f\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02j\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x15\x03\x05'\x03\x03\x19O\x05)\x05+\x17\x1f\xba\n\x1b\x05-\x1d/\x1d1\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x1d3\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##%\x03\x079=A\r\x05);!#\x1d5\r\x05)?!#\x1d7\r\x05)C!#\x1d9\x1d;\x1d=\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x17\t\x00\x00\xc0\x7f\x1f\x19\t\x00\x00\x00\x00\r\x05Q-S-\x1d?\x1dA\x0b\x03\x1dC\x1dE\x03\x01\x05\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x13\x01\t)\x05\t\x11\t\x1d\x13)\x07\t\x11\x11\t\x03\t)\x01\x13)\x01\t)\x01\x1b\x1b)\x03\t\x1b)\x03\r\r)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\r)\x03\t\x07)\x05\t\x05\x07)\x03\x05\r)\x05\t\x11\x07)\x03\t\r\x04n\x03\x05\x01Q\x03\x07\x01\x07\x04F\x03\x03\x01\x05\tP\x03\x03\x07\x04\x1a\x03\x039c\x03\x0b\x13\x00\x05B\x03\x05\x03\x15\x05B\x03\x07\x03\x17\x05B\x03\t\x03\x19\x0bG\x01\x17\x0b\x0b\x05\x0b\x05\x05\x1d\x03\x01\rF\x01\r\x03\x05\x03\x0f\x0f\x06\x01\x03\x11\x03\x13\x11\x06\x01\x03\x11\x03\x13\x13\x06\x01\x03\x11\x03\x17\x15\x06\x01\x03\x05\x05\x15\x19\x03F\x01\x0f\x03\x1d\x03\x07\x17F\x01\x11\x03/\x05\x11\x1d\x03F\x01\x13\x031\x03\x1f\x03F\x01\x0f\x03\x0b\x03\x05\x03F\x01\x15\x035\x03!\x07\x06\x01\x03\x0b\x07%\x0b#\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x03)\x07\x06\x01\x03\x05\x07-\r+\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x031\x07\x06\x01\x03\x05\x075\x1b3\x19\x04\x03\x07/'7\x06\x03\x01\x05\x01\x00\x8a\x07G)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x17\x15\x11\x11\x1b\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08M\x19\x05#\x01\x0b157EG\x03I\x03K\x03M\x11UWY[]_ae\x03k\x03'\x05mo\x03+\x03q\x03/",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["jacobi"]["c128"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvdj_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 3.796399075115019 +0.6837988589198791j ,
-3.4038861521636834 -0.8791658549931785j ,
5.587025360536978 +1.3741531770045181j ,
2.546764576857623 +1.6809085185078099j ],
[-0.29341875154987235-0.7960296026166453j ,
-3.978920434463012 -3.064639731239574j ,
7.797646929978315 +0.8348826539375432j ,
1.3257327511547885 +4.524484078738454j ],
[ 2.664074451502439 +3.393033766292215j ,
-1.9919844260791377 -0.6279428424058938j ,
4.9406207893724465 -1.3141491766038624j ,
-0.09537336365814258-1.6177405195744095j ],
[-2.4688921567972058 +0.9746213770706899j ,
0.5921515121270726 +4.164182480017167j ,
-2.2589950508277568 +2.6432862086413222j ,
-2.556559707542412 +8.972441493886869j ]],
[[-1.1112549656532051 +1.9704765658574988j ,
8.989616960518267 -3.6609418049818268j ,
1.2346549691110358 -0.24414962907971388j,
-2.9090211908941734 +4.153707428606018j ],
[-0.5588638014996116 -1.67573324865607j ,
3.149773407631379 +4.604525381223088j ,
1.128102476802749 -1.3129091142659617j ,
-3.9361491309229013 -0.4879709640370399j ],
[-1.7318669937731295 -2.869679468975394j ,
-3.0523599142955913 +8.95268082648917j ,
-2.6723736459369936 -2.1677507057699845j ,
2.2856025738573873 +4.128295675578359j ],
[-3.0608685537602445 -0.21903335326027606j,
-4.833657640993765 -2.184980999873441j ,
-1.4399110875167116 +0.23459254553652992j,
1.9463909302306492 -6.990119453626141j ]]]),),
expected_outputs=(array([[[ 0.5058374100694445 +0.08736973325822138j ,
-0.15357420292088966 +0.15092672888666417j ,
0.4000099867967776 +0.03343224253498994j ,
-0.08606623675966052 -0.7222174392115995j ],
[ 0.5494982003369341 +0.21627530772570236j ,
-0.24559859535168804 +0.42796239936478525j ,
-0.35789689606765984 -0.3818532162746145j ,
0.1870806116265344 +0.3144916716626133j ],
[ 0.35367519290177574 -0.09104684890511797j ,
-0.2722754551548636 -0.01520202737842482j ,
0.20564615603913944 +0.6788851354546692j ,
-0.26426814455261194 +0.46823742188539086j ],
[-0.26578682435990086 +0.42866473683552114j ,
0.3753385929542211 +0.7035065862655021j ,
0.11671177248101379 +0.21948854675620316j ,
-0.21652589234437378 +0.03351132737269058j ]],
[[ 0.6447398173517832 -0.15549362721502685j ,
0.15703159957366286 -0.18175797380505443j ,
-0.26984447689730545 -0.021902292069557842j,
-0.5809856180118731 +0.30265058246070115j ],
[ 0.1676295953197978 +0.35517400244986735j ,
0.1471535132890569 -0.11945200848513432j ,
-0.048587988684360706-0.5959669737080945j ,
0.4658299942602748 +0.485070920592558j ],
[-0.18394077259047054 +0.4607473838424225j ,
-0.6953163697460342 -0.28046734855011674j ,
-0.31852936346663746 +0.1408757043223249j ,
-0.18675394341953266 +0.18859188206317357j ],
[-0.38736135459868454 -0.0985538837287597j ,
0.1400965142945759 +0.5697616659026354j ,
-0.6360840273269505 -0.20798320176325613j ,
-0.16572122692263513 +0.14373411782489567j ]]]), array([[14.800219047973494 , 10.208626444252278 , 5.121442071722992 ,
2.3317857898198886],
[16.387010501961203 , 10.85923071644345 , 4.434400577803048 ,
0.7913357405499906]]), array([[[ 0.22661733590854113 +0.12716811836891184j ,
-0.24780255554134717 -0.18478518062614943j ,
0.7440472412639979 -0.05201706404993294j ,
0.5257597465394888 +0.0646977585249859j ],
[-0.1730295430553622 +0.08448133745769008j ,
0.3682599006560183 +0.4301598551539888j ,
-0.2470430370452839 -0.1549815154245142j ,
0.6735915557971368 +0.3217074899922104j ],
[ 0.9230886222728419 -0.02650327461707979j ,
0.26368770031726146 +0.17940633187899951j ,
-0.07583142147304134 +0.04327009313242354j ,
-0.11210558432341579 +0.15904958085235385j ],
[ 0.13986036936906365 +0.15179124656114312j ,
-0.2301030912215956 -0.6550803059657956j ,
-0.4696876553365799 -0.3611221351860597j ,
0.15426668090371212 +0.31702832012208193j ]],
[[-0.09203102957042177 +0.1296290046576001j ,
0.9338316046889822 -0.07199509757761581j ,
0.035650554858596174+0.04949401150838972j ,
-0.11826036178363568 +0.2824813592971463j ],
[ 0.09583662270821797 +0.27782657909919606j ,
-0.029480552998465036-0.2320825865452519j ,
0.27250121923986637 +0.16010847179743545j ,
-0.7541782211683652 -0.4361420378707925j ],
[ 0.77179564007941 -0.03313538642869378j ,
0.11720446778594507 +0.18064177478753635j ,
0.40879934622962677 +0.326378840223041j ,
0.2808442732051157 -0.06596695716827156j ],
[ 0.5393502219452392 +0.0262529741146748j ,
0.14580725795066302 -0.020389102179641145j,
-0.6823399052632065 -0.3964340353663506j ,
-0.12461789414204025 -0.2201348160267379j ]]])),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f64>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f64>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f64>>) -> (tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>>, tensor<2x4x4xcomplex<f64>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%2 = stablehlo.real %1 : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xf64> loc(#loc3)
%3 = stablehlo.imag %1 : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xf64> loc(#loc3)
%4 = stablehlo.negate %3 : tensor<2x4x4xf64> loc(#loc3)
%5 = stablehlo.complex %2, %4 : tensor<2x4x4xcomplex<f64>> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %0#2, %13 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
%16 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%18 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%19 = stablehlo.select %18, %5, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
return %15, %11, %19 : tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01+\x05\x01\x05\x1b\x01\x03\x0b\x03\x19\x0f\x13\x17\x1b\x1f#'+/37;\x03\xbfs9\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0bO/\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x035\x1b\x07\x07\x17\x07\x07\x1b\x0b\x0f\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02\x9a\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x15\x03\x05'\x03\x03\x19O\x05)\x05+\x17\x1f\xba\n\x1b\x05-\x1d/\x1d1\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x1d3\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##%\x03\x079=A\r\x05);!#\x1d5\r\x05)?!#\x1d7\r\x05)C!#\x1d9\x1d;\x1d=\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19\t\x00\x00\x00\x00\r\x05Q-S-\x1d?\x1dA\x0b\x03\x1dC\x1dE\x03\x01\x05\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x13\x01\x0b)\x05\t\x11\t\x1d\x13)\x07\t\x11\x11\t\x03\t)\x01\x13)\x01\t)\x01\x1b\x1b)\x03\t\x1b)\x03\r\r)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\r)\x03\t\x07)\x05\t\x05\x07)\x03\x05\r)\x05\t\x11\x07)\x03\t\r\x04n\x03\x05\x01Q\x03\x07\x01\x07\x04F\x03\x03\x01\x05\tP\x03\x03\x07\x04\x1a\x03\x039c\x03\x0b\x13\x00\x05B\x03\x05\x03\x15\x05B\x03\x07\x03\x17\x05B\x03\t\x03\x19\x0bG\x01\x17\x0b\x0b\x05\x0b\x05\x05\x1d\x03\x01\rF\x01\r\x03\x05\x03\x0f\x0f\x06\x01\x03\x11\x03\x13\x11\x06\x01\x03\x11\x03\x13\x13\x06\x01\x03\x11\x03\x17\x15\x06\x01\x03\x05\x05\x15\x19\x03F\x01\x0f\x03\x1d\x03\x07\x17F\x01\x11\x03/\x05\x11\x1d\x03F\x01\x13\x031\x03\x1f\x03F\x01\x0f\x03\x0b\x03\x05\x03F\x01\x15\x035\x03!\x07\x06\x01\x03\x0b\x07%\x0b#\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x03)\x07\x06\x01\x03\x05\x07-\r+\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x031\x07\x06\x01\x03\x05\x075\x1b3\x19\x04\x03\x07/'7\x06\x03\x01\x05\x01\x00\x8a\x07G)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x17\x15\x11\x11\x1b\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08M\x19\x05#\x01\x0b157EG\x03I\x03K\x03M\x11UWY[]_ae\x03k\x03'\x05mo\x03+\x03q\x03/",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["qr"]["f32"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvd_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 7.064613 , 4.4742312 , -0.12700312, -0.71483076],
[-0.59317935, 4.0224333 , 0.5515773 , 6.009665 ],
[-5.193879 , -1.0297644 , -4.388829 , 2.3485358 ],
[-0.8724199 , -1.5610907 , 0.47096923, 0.10478485]],
[[ 0.7020009 , 0.0506321 , 2.5788887 , -0.44895908],
[ 2.617715 , -1.5580447 , -0.9952533 , 1.3444504 ],
[ 0.7077899 , -0.9494638 , 2.3607216 , -4.8069396 ],
[-0.9731158 , 2.762172 , -3.6125846 , -0.59783787]]],
dtype=float32),),
expected_outputs=(array([[[-0.7763474 , -0.1832199 , -0.54734766 , 0.25323072 ],
[-0.025878029, -0.93325573 , 0.35778683 , 0.018767256],
[ 0.6168491 , -0.29107273 , -0.72102463 , 0.12205475 ],
[ 0.12693313 , 0.10363644 , 0.22917844 , 0.959492 ]],
[[-0.35366213 , -0.12650901 , 0.30348226 , 0.8756809 ],
[ 0.08954996 , -0.56219155 , -0.7897131 , 0.22863595 ],
[-0.7685047 , 0.45641905 , -0.4388071 , -0.09236202 ],
[ 0.5256468 , 0.6779513 , -0.30281967 , 0.41518393 ]]],
dtype=float32), array([[10.327013 , 7.659154 , 3.736682 , 0.61586076],
[ 6.3917613 , 4.5939665 , 2.8317585 , 1.2306355 ]],
dtype=float32), array([[[-0.8505677 , -0.42713356 , -0.24819751 , 0.18024875 ],
[ 0.088860154, -0.5791473 , 0.108991824, -0.8030028 ],
[-0.14292236 , -0.16727918 , 0.94716436 , 0.23338944 ],
[ 0.49820864 , -0.6739162 , -0.17145996 , 0.51790595 ]],
[[-0.16729502 , 0.31668338 , -0.7375667 , 0.5724681 ],
[-0.41296393 , 0.5025677 , -0.24780482 , -0.7179688 ],
[-0.6604038 , 0.29167938 , 0.5744389 , 0.38575882 ],
[ 0.6044337 , 0.7497071 , 0.25418177 , 0.08939337 ]]],
dtype=float32)),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19O\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x01\x1d)\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##\x1d\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x11\t\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\x00\x00\r\x07Q-S-U/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\t\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xc2\x02\x05\x01Q\x03\x07\x01\x07\x04\x9a\x02\x03\x01\x05\tP\x03\x03\x07\x04n\x02\x03-K\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\x03F\x01\x0b\x03\x17\x03\x05\rF\x01\r\x03'\x05\x0f\x11\x03F\x01\x0f\x03)\x03\x13\x03F\x01\x0b\x03\t\x03\x03\x03F\x01\x11\x03-\x03\x15\x05\x06\x01\x03\t\x07\x19\t\x17\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03\x1d\x05\x06\x01\x03\x05\x07!\x0b\x1f\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03%\x05\x06\x01\x03\x05\x07)\r'\x0f\x04\x03\x07#\x1b+\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08E\x15\x05#\x01\x0b379GI\x03K\x03M\x11WY[]/_ae\x03'\x05km\x03+\x03o\x031",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["qr"]["f64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvd_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 1.6666730531514784 , -0.283751211500066 ,
3.0028872858127387 , -3.539066961520449 ],
[-3.5009517179281326 , -3.0927716333012025 ,
0.7807364288380038 , -1.7085927853808216 ],
[ 1.6481964138894265 , -1.0457448512148775 ,
6.119638350643893 , 0.6798789946663015 ],
[ 0.45137958693199876 , -3.0487560288436093 ,
-3.5653048640383225 , 3.1078238891060703 ]],
[[-0.18274271852634477 , -4.107847311422953 ,
-6.970910660834766 , 0.026925818090434366],
[ 0.609826504319294 , -0.08188345529211022 ,
0.4988730098060886 , 0.5371129476436916 ],
[-1.8478868672212718 , 5.777430685108351 ,
3.6156805021156426 , -5.4328316768756855 ],
[ 2.839181461365762 , -5.931652277530413 ,
-6.898420189145518 , 0.33823029148249784 ]]]),),
expected_outputs=(array([[[ 0.5185598212776665 , 0.08640745162480469 ,
-0.22604389271595665 , 0.8200814731634898 ],
[ 0.058783626848861056, 0.9889125760720038 ,
0.040834454253227216, -0.13011129638493632 ],
[ 0.6559139736574813 , -0.09584764795840599 ,
0.7197537968326447 , -0.20626332559785515 ],
[-0.5453595659120885 , 0.07347719082261084 ,
0.6551268410442749 , 0.5176802762713159 ]],
[[ 0.5266925395219686 , 0.49814307967604127 ,
0.6468374971059552 , 0.23674816434446275 ],
[-0.004398828045962558, -0.15543570506880638 ,
-0.22846771664609714 , 0.9610530132891235 ],
[-0.5473983592728412 , 0.8074523783002769 ,
-0.20514413534806356 , 0.07931946025360624 ],
[ 0.6503311890022828 , 0.27516153535306986 ,
-0.6980828307017146 , -0.11847293172915115 ]]]), array([[ 8.364583718088932 , 5.052514677781009 , 4.631967246026687 ,
2.5284502274249308 ],
[14.318697073130886 , 5.219146075652053 , 2.112106833842507 ,
0.15087576902549965]]), array([[[ 0.17853631156202626 , 0.0774459691431368 ,
0.9039781164136 , -0.3807236167650057 ],
[-0.6814293626119452 , -0.6346900565025612 ,
0.036225601414469004, -0.3626434361038585 ],
[ 0.20775315131153385 , -0.6071183144157039 ,
0.30699755938819207 , 0.7028502535752935 ],
[ 0.6786880265219302 , -0.4717817359132189 ,
-0.29540441665143263 , -0.47910415040709536 ]],
[[ 0.19268560067166468 , -0.641350735371731 ,
-0.7081088748435267 , 0.22388236844328566 ],
[-0.1718035739905713 , 0.19146225969939276 ,
-0.4845129100140278 , -0.8361058396546505 ],
[-0.8808414232069264 , 0.1501707384172312 ,
-0.25997252143926985 , 0.3660347313883319 ],
[ 0.39683016319410813 , 0.7276401491617723 ,
-0.4429936224079104 , 0.34179275213656773 ]]])),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19O\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x01\x1d)\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##\x1d\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\t\x00\x00\x00\x00\r\x07Q-S-U/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\x0b\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xc2\x02\x05\x01Q\x03\x07\x01\x07\x04\x9a\x02\x03\x01\x05\tP\x03\x03\x07\x04n\x02\x03-K\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\x03F\x01\x0b\x03\x17\x03\x05\rF\x01\r\x03'\x05\x0f\x11\x03F\x01\x0f\x03)\x03\x13\x03F\x01\x0b\x03\t\x03\x03\x03F\x01\x11\x03-\x03\x15\x05\x06\x01\x03\t\x07\x19\t\x17\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03\x1d\x05\x06\x01\x03\x05\x07!\x0b\x1f\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03%\x05\x06\x01\x03\x05\x07)\r'\x0f\x04\x03\x07#\x1b+\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08E\x15\x05#\x01\x0b379GI\x03K\x03M\x11WY[]/_ae\x03'\x05km\x03+\x03o\x031",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["qr"]["c64"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvd_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[ 1.0732381 +3.5808065j , 3.9696057 +0.20698753j,
-0.6425436 +0.5669031j , 2.2608232 +3.3495777j ],
[-0.14261106+1.9452835j , 0.7328605 -3.497075j ,
-0.2833068 +2.5005085j , -5.3383408 -0.13752732j],
[ 0.378204 +0.31973448j, -0.82705253-1.2925721j ,
4.6363106 -0.6026158j , -4.2700663 +0.3752707j ],
[-0.5445609 -2.3843932j , -1.5469118 -0.22753051j,
-1.2669541 -5.0028024j , 0.03133653-3.4463193j ]],
[[ 2.5795584 +0.6289093j , 1.27082 +3.8879561j ,
1.9604253 +1.1865004j , -2.399359 +1.3273407j ],
[-1.4925842 +5.878235j , -5.6121607 -3.4182298j ,
-1.9998045 +0.10950515j, 1.9120374 +3.3194423j ],
[ 0.40499273-2.4316337j , 1.6648822 +2.6184802j ,
-4.0471325 +3.133329j , -0.76832575-1.7445682j ],
[-1.895143 +0.5848787j , 2.6240315 +5.021989j ,
1.449456 +3.472618j , -2.0201976 -3.5186274j ]]],
dtype=complex64),),
expected_outputs=(array([[[-0.25113735 -0.21140018j , 0.018947408 -0.64372665j ,
-0.070043676 -0.5433938j , 0.21103051 +0.3643902j ],
[ 0.34577918 -0.55255175j , -0.21863852 +0.16678011j ,
-0.48150375 -0.24859619j , -0.45313844 +0.022904329j],
[-0.07727728 -0.37749314j , -0.107260175 +0.56208193j ,
0.30505398 -0.3713933j , 0.53557503 -0.07908653j ],
[ 0.14142953 +0.5467069j , 0.13847762 +0.4037597j ,
-0.24584742 -0.33873183j , -0.0011076198+0.56897277j ]],
[[-0.15231489 +0.2637356j , -0.333622 -0.3903981j ,
-0.3809538 -0.1761738j , -0.40208697 -0.5528941j ],
[ 0.08597728 -0.6945265j , 0.21902993 -0.445389j ,
-0.33454537 -0.015148819j, -0.28255132 +0.26816013j ],
[-0.27464584 +0.24423108j , 0.3478235 +0.3221502j ,
-0.740009 -0.16378604j , 0.15596838 +0.20345287j ],
[-0.25253323 +0.4675811j , 0.37721652 -0.35055056j ,
0.3376375 -0.15247335j , -0.38508245 +0.40851057j ]]],
dtype=complex64), array([[ 9.668089 , 8.574805 , 4.549492 , 0.5780793],
[12.876013 , 7.7651014, 4.0119534, 2.2829206]], dtype=float32), array([[[-0.38075778 +0.j , 0.14002012 +0.060418203j ,
-0.4637056 +0.22876605j , -0.489978 -0.5695017j ],
[-0.32981405 +0.j , -0.20355047 +0.51292306j ,
-0.34164208 -0.4227407j , -0.19677992 +0.50254196j ],
[-0.3292037 +0.j , 0.17828293 +0.62404454j ,
0.63654786 +0.14848639j , 0.07557414 -0.19353445j ],
[ 0.7986684 +0.j , 0.056182697 +0.4978434j ,
-0.099771045 -0.0043060416j, -0.28370282 -0.14375056j ]],
[[-0.3410219 +0.j , 0.35656637 -0.6787797j ,
0.22528769 -0.27213925j , -0.21556833 +0.35289994j ],
[-0.72291875 +0.j , -0.12834571 -0.11083051j ,
-0.3442186 +0.47835365j , -0.14619395 -0.28275603j ],
[-0.3274419 +0.j , -0.19452412 +0.057822693j ,
0.53667945 -0.43909317j , 0.17421937 -0.5834539j ],
[ 0.50385934 -0.j , -0.06922947 -0.58084977j ,
0.0073776054+0.21678242j , -0.24243501 -0.546006j ]]],
dtype=complex64)),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f32>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f32>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f32>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f32>>) -> (tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>>, tensor<2x4x4xcomplex<f32>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbds7\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x033\x1b\x07\x17\x07\x07\x07\x0b\x0f\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xf6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19Q\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1d)\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##!\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x15\t\x00\x00\xc0\x7f\x1f\x17\t\x00\x00\x00\x00\r\x07S-U-W/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03e\x15\x03\x01\x01\x01\x03\x0b%i%%k\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11\x01)\x05\t\x11\r\x1d\t\x13\x03\r)\x01\x11)\x01\r)\x01\x19\x1b)\x03\t\x19)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xda\x02\x05\x01Q\x03\x07\x01\x07\x04\xb2\x02\x03\x01\x05\tP\x03\x03\x07\x04\x86\x02\x03/O\x03\x0b\x13\x00\x05B\x03\x05\x03\x13\x05B\x03\x07\x03\x15\x05B\x03\t\x03\x17\x0bG\x01\x17\x0b\x0b\x05\t\x05\x05\x1b\x03\x01\x03F\x01\r\x03\x1b\x03\x07\rF\x01\x0f\x03+\x05\x11\x13\x03F\x01\x11\x03-\x03\x15\x03F\x01\r\x03\t\x03\x05\x03F\x01\x13\x031\x03\x17\x07\x06\x01\x03\t\x07\x1b\x0b\x19\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03\x1f\x07\x06\x01\x03\x05\x07#\r!\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03'\x07\x06\x01\x03\x05\x07+\x0f)\x0f\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08I\x17\x05#\x01\x0b379GI\x03K\x03M\x03O\x11Y[]_/acg\x03'\x05mo\x03+\x03q\x031",
xla_call_module_version=9,
nr_devices=1,
) # End paste
data_2024_10_08["qr"]["c128"] = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cusolver_gesvd_ffi'],
serialized_date=datetime.date(2024, 10, 8),
inputs=(array([[[-4.699732537674587 -0.18164091805215468j,
-2.6529672987457267 +2.1873545441571416j ,
-4.583623180305504 +0.05141217967419283j,
-1.4684446379730842 +0.5956695859134695j ],
[ 2.217429580673316 -1.6820541069935535j ,
-1.489637886109648 -1.1907523648513954j ,
-5.37070728884717 +0.3011497067658051j ,
-3.5377553884933244 +1.560799473477663j ],
[ 0.4865985561509131 +4.547548126143047j ,
1.9744285723487844 +1.579347193702052j ,
3.662108610237921 -3.8947365367486944j ,
-0.46900368026456773+3.897268760016375j ],
[-3.9057171822032837 +0.894017787659835j ,
-2.665956542656175 -5.446062606216615j ,
6.586068520522582 +7.82920032979931j ,
0.2438426632437082 -2.5324000439269967j ]],
[[ 2.3593407739528036 +0.1518669531658939j ,
0.6163481796609258 +2.2151855304705617j ,
1.1710769743888314 -6.27345033430341j ,
0.9738490103384626 +0.5395897278168652j ],
[-2.4788654273898656 +0.4265527313512031j ,
-1.1807578044484868 -0.0496832499163036j ,
-4.4976038167764765 +1.058853052811918j ,
-1.1727797045618331 -5.283007446632174j ],
[ 2.1607883932036422 +0.15328185326939148j,
0.33959787374719413-0.44019437888510504j,
5.554548585416958 -5.5054723821239575j ,
3.6501512907075853 +2.5205805340930167j ],
[ 1.3385284474824868 -5.630140770855095j ,
-0.27414990799969 -0.46452124262376304j,
1.611578799750626 +8.022764935423794j ,
-2.616414597337455 +0.02175053549931295j]]]),),
expected_outputs=(array([[[ 0.010550471640845436-0.1718885244537855j ,
0.7192461888739357 +0.05423301579908536j ,
-0.3986393309169209 +0.0992550693722462j ,
0.2704504005919772 -0.45626573218528016j ],
[ 0.1757158215796512 -0.288981687849456j ,
0.2705184461915344 +0.2220238094575261j ,
0.8080741699984124 +0.3193834688705089j ,
0.08837920732925486 -0.018389772805120566j],
[ 0.18559809104086944 +0.39876558214096686j ,
-0.035649910397314105-0.49144324300371245j ,
0.15413388674267017 +0.0406886784503357j ,
0.7321684611544159 +0.047628802113382905j],
[-0.8024624385313128 +0.13619820370204241j ,
-0.23383020362294066 +0.2445505198523447j ,
0.12332909849753147 +0.18873939840686926j ,
0.26625208768341435 -0.3182762352506706j ]],
[[-0.19373811197499574 -0.3678054849282306j ,
0.2636906578352804 +0.07973830233400031j ,
-0.22842504880430503 +0.7133040647598473j ,
-0.3539448620401726 +0.25502167015286226j ],
[-0.24171920617553377 +0.3311681855178457j ,
-0.516368216465723 +0.1379819660618946j ,
-0.45968933418302504 +0.04091891748552965j ,
0.20304338982129244 +0.5403786083964173j ],
[ 0.02755060979123564 -0.5776868042861196j ,
0.27416989428331057 -0.1571567241856725j ,
-0.2065632715230957 -0.2247313731964578j ,
0.630967059955887 +0.27268947023173257j ],
[ 0.4031484266707497 +0.40258464155812185j ,
0.22887811112799977 -0.6972670403334557j ,
-0.29285414368652807 +0.2170127687269824j ,
0.03733951130463506 +0.050775060769343766j]]]), array([[15.105031148122244 , 9.10491991264034 , 5.006211740104105 ,
3.446376589720919 ],
[15.343823952995173 , 7.3753715646873195, 3.7496815109995807,
0.8625145657311305]]), array([[[ 0.398346102099369 +0.j ,
0.1371865428555439 +0.20963212142757817j ,
-0.40914166828555465 -0.7712183712683253j ,
-0.01748398034296661 +0.12678422058548106j ],
[-0.4705168176319191 +0.j ,
-0.44062591598236445 +0.5013959991714195j ,
-0.27697957785810795 +0.006226163227350533j,
-0.46230452584693416 +0.2063561296099042j ],
[ 0.6106772453975169 +0.j ,
-0.2591680391544425 -0.21982460883707244j ,
0.0568261232179787 +0.2729248746486013j ,
-0.41495999423821034 +0.511540202216849j ],
[-0.49699860082446057 +0.j ,
0.2086557264813747 -0.5767641952667593j ,
0.0041166738594973 -0.2886775449288538j ,
-0.0862160321999865 +0.5348021741590255j ]],
[[-0.09961676341576099 +0.j ,
-0.045561706463074156+0.0200549951728976j ,
0.6993928144861429 +0.5554242311281644j ,
-0.27729760005165593 +0.3362411099877399j ],
[ 0.9183964339371264 +0.j ,
0.18513592789573766 +0.048643367601494444j,
-0.07592192157413628 +0.08808182166509272j ,
0.022653328707496152+0.3253779069593685j ],
[-0.36489390866852983 +0.j ,
0.5302626885445338 -0.13646920230359977j ,
-0.339383962926636 -0.005000573360891423j,
-0.01709833277905625 +0.6719756248311359j ],
[ 0.11609016321265409 +0.j ,
0.16300036810869795 -0.7965607051693728j ,
0.13402108640742152 -0.23592994174105483j ,
-0.4709038361949722 -0.17340699243933513j ]]])),
mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f64>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f64>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f64>>) -> (tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>>, tensor<2x4x4xcomplex<f64>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbds7\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0bO/\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x033\x1b\x07\x17\x07\x07\x07\x0b\x0f\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02&\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19Q\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1d)\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##!\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17\t\x00\x00\x00\x00\r\x07S-U-W/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03e\x15\x03\x01\x01\x01\x03\x0b%i%%k\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11\x01)\x05\t\x11\r\x1d\x0b\x13\x03\r)\x01\x11)\x01\r)\x01\x19\x1b)\x03\t\x19)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xda\x02\x05\x01Q\x03\x07\x01\x07\x04\xb2\x02\x03\x01\x05\tP\x03\x03\x07\x04\x86\x02\x03/O\x03\x0b\x13\x00\x05B\x03\x05\x03\x13\x05B\x03\x07\x03\x15\x05B\x03\t\x03\x17\x0bG\x01\x17\x0b\x0b\x05\t\x05\x05\x1b\x03\x01\x03F\x01\r\x03\x1b\x03\x07\rF\x01\x0f\x03+\x05\x11\x13\x03F\x01\x11\x03-\x03\x15\x03F\x01\r\x03\t\x03\x05\x03F\x01\x13\x031\x03\x17\x07\x06\x01\x03\t\x07\x1b\x0b\x19\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03\x1f\x07\x06\x01\x03\x05\x07#\r!\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03'\x07\x06\x01\x03\x05\x07+\x0f)\x0f\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08I\x17\x05#\x01\x0b379GI\x03K\x03M\x03O\x11Y[]_/acg\x03'\x05mo\x03+\x03q\x031",
xla_call_module_version=9,
nr_devices=1,
) # End paste
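
As a quick sanity check on the testdata above (a hedged sketch, not part of the generated file), the expected outputs should satisfy the SVD reconstruction identity A = U diag(S) V^H up to roughly the test tolerances:

import numpy as np

entry = data_2024_10_08["qr"]["c64"]
a, = entry["inputs"]
u, s, vt = entry["expected_outputs"]
# diag(S) @ V^H scales the rows of V^H by the singular values.
assert np.allclose(u @ (s[..., :, None] * vt), a, atol=1e-3)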

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
from collections.abc import Callable
import enum
import functools
from functools import partial
import math
@@ -308,6 +309,13 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
return q, r
class SvdAlgorithm(enum.Enum):
"""Enum for SVD algorithm."""
DEFAULT = "default"
QR = "QR"
JACOBI = "Jacobi"
@overload
def svd(
x: ArrayLike,
@@ -315,6 +323,7 @@ def svd(
full_matrices: bool = True,
compute_uv: Literal[True],
subset_by_index: tuple[int, int] | None = None,
algorithm: SvdAlgorithm | None = None,
) -> tuple[Array, Array, Array]:
...
@@ -326,6 +335,7 @@ def svd(
full_matrices: bool = True,
compute_uv: Literal[False],
subset_by_index: tuple[int, int] | None = None,
algorithm: SvdAlgorithm | None = None,
) -> Array:
...
@@ -337,6 +347,7 @@ def svd(
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
algorithm: SvdAlgorithm | None = None,
) -> Array | tuple[Array, Array, Array]:
...
@@ -348,6 +359,7 @@ def svd(
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
algorithm: SvdAlgorithm | None = None,
) -> Array | tuple[Array, Array, Array]:
"""Singular value decomposition.
@@ -360,6 +372,7 @@ def svd(
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
algorithm=algorithm,
)
if compute_uv:
s, u, v = result
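
For reference, a minimal usage sketch of the new parameter (SvdAlgorithm and its re-export under jax.lax.linalg are both part of this change; per the lowering rules below, an explicitly non-default algorithm is only implemented on GPU, so this snippet assumes a CUDA or ROCm backend):

import jax.numpy as jnp
from jax.lax.linalg import svd, SvdAlgorithm

x = jnp.ones((8, 8), jnp.float32)
# Bypass the size-based heuristic and force the Jacobi kernel.
u, s, vt = svd(x, full_matrices=True, compute_uv=True,
               algorithm=SvdAlgorithm.JACOBI)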
@@ -1062,15 +1075,12 @@ batching.primitive_batchers[eigh_p] = _eigh_batching_rule
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'),
platform='cpu')
if gpu_solver is not None:
  mlir.register_lowering(
      eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
      platform='cuda')
  mlir.register_lowering(
      eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
      platform='rocm')
mlir.register_lowering(
    eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
    platform='cuda')
mlir.register_lowering(
    eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
    platform='rocm')
mlir.register_lowering(
eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
platform='tpu')
@@ -1890,22 +1900,26 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))
# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None):
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None,
algorithm=None):
return dispatch.apply_primitive(
svd_p,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
algorithm=algorithm,
)
def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
algorithm=None):
del algorithm # unused
if isinstance(operand, ShapedArray):
batch_dims = operand.shape[:-2]
m = operand.shape[-2]
n = operand.shape[-1]
rank = min(m, n)
rank = core.min_dim(m, n)
if subset_by_index is not None:
if full_matrices and subset_by_index != (0, rank):
raise ValueError("full_matrices and subset_by_index cannot both be set")
@@ -1927,12 +1941,14 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
@config.default_matmul_precision("float32")
def _svd_jvp_rule(
primals, tangents, *, full_matrices, compute_uv, subset_by_index
primals, tangents, *, full_matrices, compute_uv, subset_by_index,
algorithm=None,
):
A, = primals
dA, = tangents
s, U, Vt = svd_p.bind(
A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index
A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index,
algorithm=algorithm,
)
if compute_uv and full_matrices:
@@ -1993,24 +2009,18 @@ def _empty_svd(a, *, full_matrices, compute_uv):
def _svd_cpu_gpu_lowering(
gesvd_impl,
ctx,
operand,
*,
full_matrices,
compute_uv,
subset_by_index,
platform: str,
target_name_prefix: str,
algorithm=None,
):
operand_aval, = ctx.avals_in
s_aval = ctx.avals_out[0]
m, n = operand_aval.shape[-2:]
# Since the last two dimensions (m, n) are used to compute the workspace
# size, we support dynamic dimensions only for the batch size for now.
if not is_constant_shape([m, n]):
raise NotImplementedError(
"Shape polymorphism for native serialization for svd on CPU and GPU is "
f"implemented only for the batch dimensions: {operand_aval.shape}")
batch_dims = operand_aval.shape[:-2]
if not (subset_by_index is None or subset_by_index == (0, min(m, n))):
@@ -2023,22 +2033,43 @@ def _svd_cpu_gpu_lowering(
full_matrices=full_matrices,
compute_uv=compute_uv,
)
if platform in ["cuda", "rocm"]:
if not is_constant_shape(operand_aval.shape):
# TODO(necula): remove the platform kwarg when we implement GPU support.
if target_name_prefix == "cpu":
if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
raise NotImplementedError(
"Shape polymorphism for native serialization for SVD is not "
f"implemented, try to upgrade jaxlib; b/261671778; {operand_aval.shape}")
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv)
"The SVD algorithm parameter is not implemented on CPU.")
target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
nb = len(batch_dims)
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
result_layouts = [layout, tuple(range(nb, -1, -1)), layout, layout,
tuple(range(nb - 1, -1, -1))]
mode = lapack._svd_computation_attr(compute_uv=compute_uv,
full_matrices=full_matrices)
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
result_layouts=result_layouts,
operand_output_aliases={0: 0})
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
if compute_uv:
s_aval, u_aval, vt_aval = ctx.avals_out
else:
s_aval, = ctx.avals_out
# TODO(danfm): It should be possible to skip instantiating these arrays
# when they are not used.
u_aval = ShapedArray((*batch_dims, m,
m if full_matrices else core.min_dim(m, n)),
operand_aval.dtype)
vt_aval = ShapedArray((*batch_dims,
n if full_matrices else core.min_dim(m, n), n),
operand_aval.dtype)
sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, vt_aval,
info_aval])
_, s, u, vt, info = rule(sub_ctx, operand, mode=mode)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
a_shape_vals=a_shape_vals)
s, u, vt, info = _svd_gpu_sub_lowering(ctx, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
target_name_prefix=target_name_prefix,
algorithm=algorithm)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
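
To make the layout bookkeeping above concrete, here is the arithmetic for a single batch dimension (nb = 1), which reproduces the layouts visible in the serialized modules earlier in this change:

nb = 1
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))  # (1, 2, 0): column-major matrices
s_layout = tuple(range(nb, -1, -1))                   # (1, 0): singular values
info_layout = tuple(range(nb - 1, -1, -1))            # (0,): per-batch info
# These match operand_layouts = [dense<[1, 2, 0]>] and result_layouts =
# [dense<[1, 2, 0]>, dense<[1, 0]>, dense<[1, 2, 0]>, dense<[1, 2, 0]>,
# dense<0>] in the MLIR above.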
@@ -2071,9 +2102,115 @@ def _svd_cpu_gpu_lowering(
return result
def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
batch_dims = a.shape[:-2]
def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv,
target_name_prefix, algorithm):
operand_aval, = ctx.avals_in
if compute_uv:
s_aval, u_aval, vt_aval = ctx.avals_out
else:
s_aval, = ctx.avals_out
u_aval = vt_aval = ShapedArray((), operand_aval.dtype)
batch_dims = operand_aval.shape[:-2]
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
nb = len(batch_dims)
m, n = operand_aval.shape[-2:]
k = core.min_dim(m, n)
transposed = False
kwargs = {}
# The Jacobi algorithm appears to outperform the default QR algorithm for
# small- to medium-sized matrices. See:
# https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
# slide 5. With this in mind, we default to using the Jacobi algorithm for
# matrices smaller than 1024x1024.
#
# Note that the Jacobi algorithm is only selected by default for matrices
# with concrete dimensions. When the matrix dimensions are symbolic, we
# always fall back to the default QR algorithm, but users can override this
# behavior by explicitly passing `algorithm=SvdAlgorithm.JACOBI`.
#
# TODO(danfm): Since this was originally implemented, hipSolver appears to
# have added support for the Jacobi algorithm, so we should investigate
# removing this condition.
if algorithm is None or algorithm == SvdAlgorithm.DEFAULT:
try:
use_jacobi = target_name_prefix == "cu" and m <= 1024 and n <= 1024
except core.InconclusiveDimensionOperation:
use_jacobi = False
else:
use_jacobi = algorithm == SvdAlgorithm.JACOBI
if use_jacobi:
target_name = f"{target_name_prefix}solver_gesvdj_ffi"
# The gesvdjbatched kernel doesn't support "econ" mode, but it also only
# supports matrices up to 32x32, so it's always worth using the batched
# version and then slicing afterwards when the matrix is small enough.
try:
econ = not full_matrices and m > 32 and n > 32
except core.InconclusiveDimensionOperation:
econ = False
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
else:
target_name = f"{target_name_prefix}solver_gesvd_ffi"
econ = not full_matrices
# Because the base gesvd kernel only supports matrices where m >= n, we
# transpose the input (and swap u and vt in the returned results) when m < n.
transposed = m < n
kwargs = {"transposed": transposed}
if transposed:
layout = tuple(range(nb + 1, -1, -1))
else:
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
result_layouts = [layout, tuple(range(nb, -1, -1)),
layout if use_jacobi or compute_uv else (),
layout if use_jacobi or compute_uv else (),
tuple(range(nb - 1, -1, -1))]
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
result_layouts=result_layouts,
operand_output_aliases={0: 0})
if use_jacobi:
# When using the Jacobi algorithm, the U and V matrices must always be
# allocated even if compute_uv is False.
u_aval = ShapedArray((*batch_dims, m, k if econ else m), u_aval.dtype)
v_aval = ShapedArray((*batch_dims, n, k if econ else n), vt_aval.dtype)
sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, v_aval,
info_aval])
elif transposed:
sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, vt_aval, u_aval,
info_aval])
else:
sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, vt_aval,
info_aval])
_, s, u, vt, info = rule(sub_ctx, operand, full_matrices=not econ,
compute_uv=compute_uv, **kwargs)
if use_jacobi and compute_uv:
vt = hlo.transpose(
vt,
mlir.dense_int_array(np.array(tuple(range(nb)) + (nb + 1, nb))))
if np.issubdtype(operand_aval.dtype, np.complexfloating):
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
if not full_matrices and not econ:
nd = len(operand_aval.shape)
u = mlir.slice_op(ctx, u, ctx.avals_out[1],
start_indices=np.zeros([nd], np.int64),
limit_indices=batch_dims + (m, k),
strides=np.ones([nd], np.int64))
vt = mlir.slice_op(ctx, vt, ctx.avals_out[2],
start_indices=np.zeros([nd], np.int64),
limit_indices=batch_dims + (k, n),
strides=np.ones([nd], np.int64))
if transposed:
return s, vt, u, info
else:
return s, u, vt, info
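
A hedged numpy sketch of the Jacobi post-processing just above: gesvdj returns V while the rest of JAX works with V^H, so the hlo.transpose plus complex(real, negate(imag)) together amount to a conjugate transpose of the last two axes (shapes here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
v = rng.standard_normal((2, 4, 4)) + 1j * rng.standard_normal((2, 4, 4))
vh = np.swapaxes(v, -1, -2).conj()  # what the lowering emits for complex inputs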
def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None):
if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
raise NotImplementedError(
"The SVD algorithm parameter is not implemented on TPU.")
batch_dims = a.shape[:-2]
fn = partial(
lax_svd.svd,
full_matrices=full_matrices,
@@ -2092,8 +2229,9 @@ def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
def _svd_tpu_lowering_rule(
ctx, operand, *, full_matrices, compute_uv, subset_by_index
ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None
):
del algorithm # unused
operand_aval, = ctx.avals_in
m, n = operand_aval.shape[-2:]
@@ -2115,7 +2253,8 @@ def _svd_tpu_lowering_rule(
def _svd_batching_rule(
batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index
batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index,
algorithm=None,
):
x, = batched_args
bd, = batch_dims
@@ -2125,6 +2264,7 @@ def _svd_batching_rule(
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
algorithm=algorithm,
)
if compute_uv:
@@ -2141,18 +2281,14 @@ ad.primitive_jvps[svd_p] = _svd_jvp_rule
batching.primitive_batchers[svd_p] = _svd_batching_rule
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo,
platform='cpu'),
platform='cpu')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='cpu'),
platform='cpu')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd,
platform='cuda'),
platform='cuda')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='cu'),
platform='cuda')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd,
platform='rocm'),
platform='rocm')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='hip'),
platform='rocm')
mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
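
For orientation, the target_name_prefix values registered above select the FFI custom call targets exercised elsewhere in this change (a hedged summary; the CPU name is produced by lapack.prepare_lapack_call, which also prepends a dtype character such as "s" or "z"):

prefix_to_target = {
    "cpu": "lapack_<t>gesdd_ffi",  # <t> is the LAPACK dtype prefix
    "cu": "cusolver_gesvd_ffi",    # or cusolver_gesvdj_ffi when Jacobi is selected
    "hip": "hipsolver_gesvd_ffi",  # or hipsolver_gesvdj_ffi
}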

View File

@@ -3314,6 +3314,7 @@ def _svd(
full_matrices: bool,
compute_uv: bool,
subset_by_index: tuple[int, int] | None = None,
algorithm: lax.linalg.SvdAlgorithm | None = None,
):
if not (
subset_by_index is None
@@ -3321,6 +3322,9 @@
):
raise NotImplementedError("subset_by_index is not implemented")
if algorithm is not None and algorithm != lax.linalg.SvdAlgorithm.DEFAULT:
raise NotImplementedError("SVD algorithm is not implemented")
result = tf.linalg.svd(operand, full_matrices, compute_uv)
if not compute_uv:
return result,

View File

@@ -30,6 +30,7 @@ from jax._src.lax.linalg import (
qr_p as qr_p,
svd as svd,
svd_p as svd_p,
SvdAlgorithm as SvdAlgorithm,
triangular_solve as triangular_solve,
triangular_solve_p as triangular_solve_p,
tridiagonal as tridiagonal,

View File

@@ -14,7 +14,6 @@
from functools import partial
import importlib
import math
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
@@ -119,138 +118,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = math.prod(batch_dims)
if ir.ComplexType.isinstance(a_type.element_type):
singular_vals_type = ir.ComplexType(a_type.element_type).element_type
else:
singular_vals_type = a_type.element_type
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
# NVIDIA's batched Jacobi solver supports a maximum matrix size of 32x32, but
# the unbatched solver has no such limit. The unbatched solver appears to
# outperform gesvd for small-moderate matrices, e.g., see:
# https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
# slide 5.
if have_jacobi_solver and m <= 1024 and n <= 1024:
# The gesvdjbatched kernel doesn't support "econ" mode. We will use that
# kernel only if b > 1 and m <= 32 and n <= 32.
econ = not full_matrices and (b <= 1 or m > 32 or n > 32)
lwork, opaque = gpu_solver.build_gesvdj_descriptor(
np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
k = min(m, n)
matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
_, s, u, v, info, _ = custom_call(
f"{platform}solver_gesvdj",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
a_type.element_type),
ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
vt = hlo.transpose(
v,
dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
if np.issubdtype(dtype, np.complexfloating):
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
if not full_matrices and not econ:
u = hlo.slice(
u,
dense_int_array(np.zeros([len(dims)], np.int64)),
dense_int_array(np.array(batch_dims + (m, min(m, n)))),
dense_int_array(np.ones([len(dims)], np.int64)))
vt = hlo.slice(
vt,
dense_int_array(np.zeros([len(dims)], np.int64)),
dense_int_array(np.array(batch_dims + (min(m, n), n))),
dense_int_array(np.ones([len(dims)], np.int64)))
elif m < n:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
k = n if full_matrices else m
matrix_layout = (num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))
_, s, vt, u, info, _ = custom_call(
f"{platform}solver_gesvd",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (k, n), a_type.element_type),
ir.RankedTensorType.get(batch_dims + (m, m), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
else:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, m, n, compute_uv, full_matrices)
k = m if full_matrices else n
matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
_, s, u, vt, info, _ = custom_call(
f"{platform}solver_gesvd",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (m, k), a_type.element_type),
ir.RankedTensorType.get(batch_dims + (n, n), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
return s, u, vt, info
cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False)
def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
"""sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
a_type = ir.RankedTensorType(a.type)

View File

@@ -213,133 +213,6 @@ def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False,
return out[:2]
# # ?gesdd: Singular value decomposition
def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
a_shape_vals: tuple[DimensionSize, ...]):
a_type = ir.RankedTensorType(a.type)
assert len(a_shape_vals) >= 2
m, n = a_shape_vals[-2:]
assert type(m) is int
assert type(n) is int
batch_dims_vals = a_shape_vals[:-2]
num_bd = len(batch_dims_vals)
fn_base = prepare_lapack_call(fn_base="gesdd", dtype=dtype)
i32_type = ir.IntegerType.get_signless(32)
workspace: list[ShapeTypePair]
# TODO(b/344892332): Remove the old kernel after the compatibility period.
if ctx.is_forward_compat():
fn = fn_base
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
if dtype == np.float32:
singular_vals_type = ir.F32Type.get()
lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0]]
elif dtype == np.float64:
singular_vals_type = ir.F64Type.get()
lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0]]
elif dtype == np.complex64:
singular_vals_type = ir.F32Type.get()
lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0], [0]]
elif dtype == np.complex128:
singular_vals_type = ir.F64Type.get()
lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0], [0]]
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (min(m, n),), singular_vals_type),
(batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type),
(batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type),
(batch_dims_vals, i32_type),
] + workspace
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
return custom_call(
fn,
result_types=result_types,
operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val,
hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 6 + [layout],
result_layouts=[
layout,
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
layout,
layout,
tuple(range(num_bd - 1, -1, -1)),
] + workspace_layouts,
operand_output_aliases={6: 0},
result_shapes=result_shapes
).results[1:5]
fn = fn_base + "_ffi"
mode_attr = _svd_computation_attr(
compute_uv=compute_uv, full_matrices=full_matrices
)
if dtype == np.float32 or dtype == np.complex64:
singular_vals_type = ir.F32Type.get()
elif dtype == np.float64 or dtype == np.complex128:
singular_vals_type = ir.F64Type.get()
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
a_elem_type = a_type.element_type
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_elem_type),
(batch_dims_vals + (min(m, n),), singular_vals_type),
(batch_dims_vals + (m, m if full_matrices else min(m, n)), a_elem_type),
(batch_dims_vals + (n if full_matrices else min(m, n), n), a_elem_type),
(batch_dims_vals, i32_type),
]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
return custom_call(
fn,
result_types=result_types,
operands=[a],
operand_layouts=[layout],
result_layouts=[
layout,
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
layout,
layout,
tuple(range(num_bd - 1, -1, -1)),
],
operand_output_aliases={0: 0},
result_shapes=result_shapes,
backend_config={
"mode": mode_attr,
},
api_version=4,
).results[1:]
# # geev: Nonsymmetric eigendecomposition (eig)
def geev_hlo(ctx, dtype, input, *,

View File

@@ -47,6 +47,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenb
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf
from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
@@ -135,6 +136,7 @@ class CompatTest(bctu.CompatTestBase):
cuda_lu_cusolver_getrf.data_2024_08_19,
cuda_qr_cusolver_geqrf.data_2024_09_26,
cuda_eigh_cusolver_syev.data_2024_09_30,
cuda_svd_cusolver_gesvd.data_2024_10_08,
rocm_qr_hipsolver_geqrf.data_2024_08_05,
rocm_eigh_hipsolver_syev.data_2024_08_05,
cpu_schur_lapack_gees.data_2023_07_16,
@@ -166,6 +168,7 @@ class CompatTest(bctu.CompatTestBase):
# The following require ROCm to test
"hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
"hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
"hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered,
@@ -646,28 +649,48 @@ class CompatTest(bctu.CompatTestBase):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (2, 4, 4)
input = jtu.rand_default(self.rng())(shape, dtype)
# del input # Input is in the testdata, here for readability
def func(input):
return lax.linalg.svd(input, full_matrices=True, compute_uv=True)
def func(operand):
return lax.linalg.svd(operand, full_matrices=True, compute_uv=True)
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
info = cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
data = self.load_testdata(info)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
*data.inputs))
data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
input))
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results, input))
*data.inputs),
expect_current_custom_calls=info["custom_call_targets"])
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}",
dtype_name=dtype_name, algorithm_name=algorithm_name)
for dtype_name in ("f32", "f64", "c64", "c128")
for algorithm_name in ("qr", "jacobi"))
@jax.default_matmul_precision("float32")
def test_gpu_svd_solver_gesvd(self, dtype_name, algorithm_name):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
def func(operand):
return lax.linalg.svd(operand, full_matrices=True, compute_uv=True,
algorithm=algorithm)
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
algorithm = dict(qr=lax.linalg.SvdAlgorithm.QR,
jacobi=lax.linalg.SvdAlgorithm.JACOBI)[algorithm_name]
info = cuda_svd_cusolver_gesvd.data_2024_10_08[algorithm_name][dtype_name]
data = self.load_testdata(info)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
*data.inputs))
@jtu.parameterized_filterable(
kwargs=[

View File

@@ -48,7 +48,6 @@ from jax._src.export import shape_poly
from jax._src.export import shape_poly_decision
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import version as jaxlib_version
import numpy as np
config.parse_flags_with_absl()
@@ -2880,6 +2879,28 @@ _POLY_SHAPE_TEST_HARNESSES = [
((2, 3, 4, 5), "b1, b2, m, n"),
]
],
[
PolyHarness( # pylint: disable=g-complex-comprehension
"svd", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}_{compute_uv=}",
lambda x, full_matrices, compute_uv: lax.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv),
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices), StaticArg(compute_uv)],
polymorphic_shapes=[poly],
symbolic_constraints=constraints)
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
for compute_uv in [True, False]
for full_matrices in ([True, False] if compute_uv else [True])
for shape, poly, constraints in [
((2, 0, 4), "b, ...", ()),
((2, 4, 0), "b, ...", ()),
((2, 3, 4, 4), "b1, b2, ...", ()),
((2, 3, 4, 5), "b1, b2, ...", ()),
((2, 3, 8, 4), "b1, b2, ...", ()),
# The constraints listed here are only for the GPU implementation,
# which selects an algorithm based on the size of the matrix; see the
# sketch after this group of harnesses.
((5, 4), "m, n", ["n <= m", "m <= 32", "n <= 32"]),
((2, 3, 4, 5), "b1, b2, m, n", ["m <= n", "m <= 32", "n <= 32"]),
]
],
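
A hedged sketch of why those size constraints make the GPU heuristics decidable under shape polymorphism (assuming the jax.export symbolic-shape API and its constraints parameter):

from jax import export

m, n = export.symbolic_shape("m, n", constraints=("m <= 32", "n <= 32"))
# The comparisons used by the GPU lowering now settle to concrete booleans
# instead of raising core.InconclusiveDimensionOperation:
assert m <= 1024 and n <= 1024  # the Jacobi-vs-QR size heuristic
assert not (m > 32 or n > 32)   # the batched-Jacobi "econ" decision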
[
# The random primitive tests, with threefry (both partitionable and
# non-partitionable), and unsafe_rbg.
@@ -3498,23 +3519,6 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
if harness.expect_error == expect_error_associative_scan and jtu.test_device_matches(["tpu"]):
harness.expect_error = None
# Exclude some harnesses that are known to fail for native serialization
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"vmap_svd:gpu",
}
name_device_key = f"{harness.group_name}:{jtu.device_under_test()}"
if name_device_key in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
# This list keeps track of the minimum jaxlib version that supports shape
# polymorphism for some new primitives as we add them. This check is
# required so that we can still run the test suite with older versions of
# jaxlib.
version_gated = {}
if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version:
raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}")
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
@@ -3553,11 +3557,11 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")
if (harness.group_name == "eigh" and
if (harness.group_name in ("eigh", "svd") and
not harness.polymorphic_shapes[0].endswith("...") and
jtu.test_device_matches(["tpu"])):
raise unittest.SkipTest(
"Shape polymorphsim for Eigh is only supported for batch dimensions on TPU.")
"Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.")
config_flags = harness.override_jax_config_flags
# Update this here rather than in harness object because vmap_random_gamma is derived