Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly, because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default, we use some heuristics based on the matrix sizes to select the algorithm, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed; a sketch of that usage follows below.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
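
A minimal usage sketch (the import path and the `SvdAlgorithm` enum spelling below are assumptions based on this description, not confirmed by this page):

    # Hypothetical usage of the new `algorithm` parameter; the enum name
    # `SvdAlgorithm` and its `QR` member are assumed here for illustration.
    import jax.numpy as jnp
    from jax.lax import linalg as lax_linalg

    a = jnp.ones((2, 4, 4), dtype=jnp.float32)
    # Requesting the QR kernel explicitly bypasses the size-based heuristics,
    # which keeps behavior consistent under shape polymorphism.
    u, s, vh = lax_linalg.svd(
        a,
        full_matrices=True,
        compute_uv=True,
        algorithm=lax_linalg.SvdAlgorithm.QR,
    )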
This commit is contained in:
parent 3e634d9530
commit 8361eb58e1
@@ -962,6 +962,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
    # eigh on GPU
    "cusolver_syevd_ffi", "hipsolver_syevd_ffi",
    # svd on GPU
    "cusolver_gesvd_ffi", "cusolver_gesvdj_ffi",
    "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
    # lu on TPU
    "LuDecomposition",
    # ApproxTopK on TPU

@@ -0,0 +1,818 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# type: ignore
# ruff: noqa

import datetime
from numpy import array, float32, complex64

data_2024_10_08 = {"jacobi": {}, "qr": {}}
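
# Each entry below records one serialized export used by the back-compat
# test: the GPU custom-call target it exercises ("jacobi" entries use
# cusolver_gesvdj_ffi, "qr" entries use cusolver_gesvd_ffi), sample inputs
# with golden outputs, the StableHLO module as text, and its serialized form.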
data_2024_10_08["jacobi"]["f32"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvdj_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 4.0477114 , -1.6619755 , 3.410165 , -0.3096872 ],
        [ 1.7056831 , 0.17458144, 0.12622283, 1.8056804 ],
        [-0.21965337, 1.7691472 , -0.8960539 , 1.1495945 ],
        [ 1.4272052 , 3.2477028 , -1.5676605 , -1.7347798 ]],

        [[-3.914366 , -2.8975718 , 1.5051024 , 1.7379241 ],
        [-5.045383 , 1.4189544 , -0.61290324, 1.7456682 ],
        [-4.823716 , 0.32512662, -1.9951408 , 3.632175 ],
        [ 0.8716193 , -0.24001008, 1.5933404 , 1.0177776 ]]],
        dtype=float32),),
    expected_outputs=(array([[[ 0.9018128 , 0.31748173 , 0.090800084, -0.27873662 ],
        [ 0.1856214 , 0.18972664 , -0.78390217 , 0.56128997 ],
        [-0.25094122 , 0.20106053 , -0.55834264 , -0.7647598 ],
        [-0.29884213 , 0.90707415 , 0.25594372 , 0.1496738 ]],

        [[-0.45493636 , -0.8541334 , -0.20037504 , -0.15276858 ],
        [-0.57877773 , 0.30431184 , -0.4479597 , 0.6097069 ],
        [-0.6747176 , 0.29090276 , 0.57102525 , -0.3661442 ],
        [ 0.052960183, -0.30532822 , 0.65811265 , 0.68619025 ]]],
        dtype=float32), array([[5.974016 , 4.183989 , 2.6312675, 0.5687128],
        [9.106636 , 3.995191 , 1.8342099, 1.6058134]], dtype=float32), array([[[ 0.60185665 , -0.48223644 , 0.6347658 , 0.047846846],
        [ 0.6833445 , 0.6709123 , -0.1184354 , -0.26246953 ],
        [-0.18304126 , -0.16886286 , 0.11772618 , -0.96131265 ],
        [ 0.37054697 , -0.53741086 , -0.75444424 , -0.0685464 ]],

        [[ 0.8786726 , 0.0290856 , 0.12085139 , -0.46095958 ],
        [ 0.034706447, 0.76957005 , -0.635503 , -0.051896986],
        [ 0.4708456 , -0.014900906, -0.06417345 , 0.8797526 ],
        [-0.07095488 , 0.6377255 , 0.75987667 , 0.10420583 ]]],
        dtype=float32)),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32> loc(#loc3)
%2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%3 = stablehlo.compare EQ, %0#4, %2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%7 = stablehlo.select %6, %0#1, %5 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#2, %9 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %1, %13 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
return %11, %7, %15 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#'+\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02\x12\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x1d\x15\x03\x05\x1f\x03\x03\x19M\x05!\x05#\x17\x1f\xba\n\x1b\x05%\x1d'\x1d)\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x1d+\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##\x1f\x03\x079=A\r\x05);!#\x1d-\r\x05)?!#\x1d/\r\x05)C!#\x1d1\x1d3\x1d5\x1f\x11\t\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\x00\x00\r\x05O-Q-\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x03%\x03\x03a\x15\x03\x01\x01\x01\x03\x0b%e%%g\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\t\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x03\r\x0b)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b\x04\xe2\x02\x05\x01Q\x03\x07\x01\x07\x04\xba\x02\x03\x01\x05\tP\x03\x03\x07\x04\x8e\x02\x03/O\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\rF\x01\x0b\x03\x05\x03\r\x03F\x01\r\x03\x17\x03\x05\x0fF\x01\x0f\x03)\x05\x0f\x13\x03F\x01\x11\x03+\x03\x15\x03F\x01\r\x03\t\x03\x03\x03F\x01\x13\x03/\x03\x17\x05\x06\x01\x03\t\x07\x1b\t\x19\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\x1f\x05\x06\x01\x03\x05\x07#\x0b!\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03'\x05\x06\x01\x03\x05\x07+\x11)\x11\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xe6\x06?)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1b\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08I\x17\x05#\x01\x0b157EG\x03I\x03K\x11SUWY[]_c\x03i\x03'\x05km\x03+\x03o\x03/",
    xla_call_module_version=9,
    nr_devices=1,
) # End paste

data_2024_10_08["jacobi"]["f64"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvdj_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[-0.23389781892940775, -0.20108449262485306,
        -0.5666115573270456 , -2.4789757754536694 ],
        [-1.8555613779538866 , 2.9994543506103533 ,
        1.087201973266891 , -1.012848084871914 ],
        [-3.195786395215201 , -2.483536010628656 ,
        -0.9470206294018368 , -0.4080455549606986 ],
        [ 0.41241574420074084, -8.059831117309347 ,
        -0.7882929058035465 , 2.8856408696497664 ]],

        [[ 1.7959513084827519 , -6.006170699401665 ,
        0.9365840264144545 , -2.8339481807478486 ],
        [ 3.37344261819673 , -2.809695890050033 ,
        0.9096330403907014 , -0.22091594063236752],
        [-1.8994569852606458 , 0.8593563734072376 ,
        3.3970755287480623 , 1.162324876281341 ],
        [ 1.2730096750276905 , 3.1397664846677484 ,
        -4.625276688205361 , -3.0618323131122303 ]]]),),
    expected_outputs=(array([[[ 0.05538281351237727 , -0.34543122319556363 ,
        -0.875138554715149 , 0.33427911101376073 ],
        [ 0.35132368873545367 , -0.3583481459523727 ,
        0.4466199100964102 , 0.7407353967047108 ],
        [-0.2350018651348275 , -0.8649389618871873 ,
        0.17010144637596905 , -0.40953658387679276 ],
        [-0.904587493327162 , 0.06437754689044517 ,
        0.07568793759482184 , 0.41454593771389403 ]],

        [[-0.7655776828405998 , -0.3259951618597633 ,
        -0.5487654998241592 , 0.08046360781859588 ],
        [-0.4502679475178967 , -0.18435823314270097 ,
        0.7918156001863894 , 0.3691867719894481 ],
        [ 0.011780696812478626, 0.5608039473610089 ,
        -0.23310619898043075 , 0.7943687102371408 ],
        [ 0.4593591211209926 , -0.7384024166677944 ,
        -0.13246124527990302 , 0.47561022634191513 ]]]), array([[9.502458469794536 , 4.039322683970868 , 2.3634256270400944,
        0.7309141361765127],
        [8.368723268857098 , 7.018568450518651 , 2.518568829019885 ,
        1.4845675373399827]]), array([[[-0.030192917862639196, 0.9383993407299717 ,
        0.13535558060292174 , -0.31650265690534246 ],
        [ 0.8755039760059299 , 0.15444343242930395 ,
        0.14222562381109818 , 0.4352147586063772 ],
        [-0.4808404523560293 , 0.20440997041138523 ,
        0.32185308272374646 , 0.7895692601131882 ],
        [ 0.03706258337996993 , -0.23188459951757628 ,
        0.9262080391966009 , -0.2949484116712019 ]],

        [[-0.2785970537508859 , 0.8741728222022489 ,
        -0.3837203651304145 , 0.10471026668126972 ],
        [-0.45773005941421924 , 0.09111437041907021 ,
        0.690652044371576 , 0.5524320028900885 ],
        [ 0.7781161676166882 , 0.18065797801542668 ,
        0.010754987459331431, 0.6014833787542634 ],
        [ 0.32772260227744276 , 0.4414552564025818 ,
        0.612897026615641 , -0.5675142177221092 ]]])),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xf64>) -> tensor<2x4x4xf64> loc(#loc3)
%2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%3 = stablehlo.compare EQ, %0#4, %2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%7 = stablehlo.select %6, %0#1, %5 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#2, %9 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %1, %13 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
return %11, %7, %15 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.7.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02"\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x1d\x15\x03\x05\x1f\x03\x03\x19M\x05!\x05#\x17\x1f\xba\n\x1b\x05%\x1d\'\x1d)\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\'\x01\x1d+\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##\x1f\x03\x079=A\r\x05);!#\x1d-\r\x05)?!#\x1d/\r\x05)C!#\x1d1\x1d3\x1d5\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\t\x00\x00\x00\x00\r\x05O-Q-\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x03%\x03\x03a\x15\x03\x01\x01\x01\x03\x0b%e%%g\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x191\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\x0b\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x03\r\x0b)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b\x04\xe2\x02\x05\x01Q\x03\x07\x01\x07\x04\xba\x02\x03\x01\x05\tP\x03\x03\x07\x04\x8e\x02\x03/O\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\rF\x01\x0b\x03\x05\x03\r\x03F\x01\r\x03\x17\x03\x05\x0fF\x01\x0f\x03)\x05\x0f\x13\x03F\x01\x11\x03+\x03\x15\x03F\x01\r\x03\t\x03\x03\x03F\x01\x13\x03/\x03\x17\x05\x06\x01\x03\t\x07\x1b\t\x19\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\x1f\x05\x06\x01\x03\x05\x07#\x0b!\x03F\x01\x11\x03\x1b\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1d\x03\'\x05\x06\x01\x03\x05\x07+\x11)\x11\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xe6\x06?)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1b\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08I\x17\x05#\x01\x0b157EG\x03I\x03K\x11SUWY[]_c\x03i\x03\'\x05km\x03+\x03o\x03/',
    xla_call_module_version=9,
    nr_devices=1,
) # End paste

data_2024_10_08["jacobi"]["c64"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvdj_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 2.4649541 -5.8653884j , 5.5100183 +2.0214202j ,
        -2.4541297 +1.862114j , 5.4709225 +4.409564j ],
        [ 1.1091617 +2.325679j , -5.0506334 +5.5802264j ,
        -1.7254959 -1.5569435j , 3.002013 -2.7583091j ],
        [ 0.8154569 +5.66862j , -0.7711735 +1.8621845j ,
        1.2456422 -1.1770611j , 0.03156909-0.22670403j],
        [ 3.9012303 -2.0669405j , -2.7752936 -0.71004313j,
        -1.0354352 -5.5713825j , 1.554125 +0.9581067j ]],

        [[-2.077837 +6.2613506j , -2.0213401 -3.2755377j ,
        -2.1061401 -0.4942127j , -5.098616 +3.4436114j ],
        [ 3.6104472 +0.75928044j, 1.3155019 -3.6494553j ,
        0.58335614-0.8751654j , -1.1484178 -4.0733714j ],
        [-0.4858576 +4.38415j , -2.3318157 +2.744366j ,
        -0.7987209 -0.23303579j, -2.6747904 +1.5206293j ],
        [ 4.013358 +0.978174j , -2.707136 -0.29939318j,
        -2.5241148 -1.44767j , -1.8191104 +0.26034543j]]],
        dtype=complex64),),
    expected_outputs=(array([[[-7.7744454e-01+0.09018909j , -2.1625498e-01-0.16622283j ,
        -1.6743444e-01+0.5261705j , 8.9576304e-02+0.01171514j ],
        [ 1.2709992e-01-0.38680157j , -3.4196457e-01+0.6261395j ,
        1.3424198e-01+0.37196818j , -1.5489596e-01+0.38061285j ],
        [ 2.6221693e-01-0.25399417j , 1.7437853e-01+0.053393736j ,
        7.3115915e-02+0.51672214j , 2.4561302e-01-0.7076709j ],
        [-5.0269447e-02-0.29305094j , -5.6865913e-01-0.24491112j ,
        4.3156582e-01-0.28307965j , 5.0845784e-01-0.05768533j ]],

        [[-1.5179910e-02+0.8092423j , -1.0622249e-01+0.09009284j ,
        -5.3503644e-01-0.053553585j , 1.8145569e-01+0.058631808j ],
        [-1.7091817e-01+0.0025223175j, 7.6241887e-01+0.2851526j ,
        -3.0385127e-02-0.115484476j , 3.0117261e-01-0.45080122j ],
        [ 1.7759532e-01+0.4123873j , -2.5654158e-01+0.052343685j ,
        6.6419083e-01-0.32588708j , -1.6986026e-04-0.4271902j ],
        [ 2.4243295e-01+0.23515853j , 4.7519770e-01-0.15375356j ,
        3.5048491e-01-0.16253117j , 8.7767310e-02+0.6924699j ]]],
        dtype=complex64), array([[13.7465725, 9.692749 , 6.012994 , 1.7378366],
        [12.048737 , 7.9871097, 3.9069395, 1.7972969]], dtype=float32), array([[[-0.29246047 +0.582183j , -0.5259067 -0.27628872j ,
        0.3469344 -0.15329583j , -0.19642796 -0.20042528j ],
        [ 0.02593917 +0.33676162j , 0.55821204 +0.18806678j ,
        0.20056807 +0.35542676j , -0.5978458 -0.12236632j ],
        [ 0.46109042 -0.034899022j, 0.24078317 -0.19413127j ,
        0.19841644 -0.3350929j , 0.17725058 -0.71234804j ],
        [-0.48504043 -0.1111859j , 0.31423682 +0.32512894j ,
        0.23620881 -0.69435585j , -0.04027982 +0.091500096j]],

        [[ 0.6148358 +0.14274344j , -0.23763065 +0.3584556j ,
        -0.13778959 +0.19841002j , 0.234248 +0.55083406j ],
        [ 0.734292 -0.11843189j , -0.07720101 -0.4717579j ,
        -0.051302787-0.19603764j , -0.16576114 -0.38695592j ],
        [ 0.019250087+0.17437238j , -0.43637297 +0.6207004j ,
        0.033975117-0.27825257j , 0.024782527-0.5606609j ],
        [-0.060103483+0.11833174j , -0.074751794-0.072441235j,
        -0.5370554 +0.7304642j , 0.07712744 -0.37893948j ]]],
        dtype=complex64)),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f32>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f32>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f32>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f32>>) -> (tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>>, tensor<2x4x4xcomplex<f32>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%2 = stablehlo.real %1 : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xf32> loc(#loc3)
%3 = stablehlo.imag %1 : (tensor<2x4x4xcomplex<f32>>) -> tensor<2x4x4xf32> loc(#loc3)
%4 = stablehlo.negate %3 : tensor<2x4x4xf32> loc(#loc3)
%5 = stablehlo.complex %2, %4 : tensor<2x4x4xcomplex<f32>> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %0#2, %13 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
%16 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
%18 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%19 = stablehlo.select %18, %5, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
return %15, %11, %19 : tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01+\x05\x01\x05\x1b\x01\x03\x0b\x03\x19\x0f\x13\x17\x1b\x1f#'+/37;\x03\xbfs9\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x035\x1b\x07\x07\x17\x07\x07\x1b\x0b\x0f\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02j\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x15\x03\x05'\x03\x03\x19O\x05)\x05+\x17\x1f\xba\n\x1b\x05-\x1d/\x1d1\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x1d3\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##%\x03\x079=A\r\x05);!#\x1d5\r\x05)?!#\x1d7\r\x05)C!#\x1d9\x1d;\x1d=\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x17\t\x00\x00\xc0\x7f\x1f\x19\t\x00\x00\x00\x00\r\x05Q-S-\x1d?\x1dA\x0b\x03\x1dC\x1dE\x03\x01\x05\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x13\x01\t)\x05\t\x11\t\x1d\x13)\x07\t\x11\x11\t\x03\t)\x01\x13)\x01\t)\x01\x1b\x1b)\x03\t\x1b)\x03\r\r)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\r)\x03\t\x07)\x05\t\x05\x07)\x03\x05\r)\x05\t\x11\x07)\x03\t\r\x04n\x03\x05\x01Q\x03\x07\x01\x07\x04F\x03\x03\x01\x05\tP\x03\x03\x07\x04\x1a\x03\x039c\x03\x0b\x13\x00\x05B\x03\x05\x03\x15\x05B\x03\x07\x03\x17\x05B\x03\t\x03\x19\x0bG\x01\x17\x0b\x0b\x05\x0b\x05\x05\x1d\x03\x01\rF\x01\r\x03\x05\x03\x0f\x0f\x06\x01\x03\x11\x03\x13\x11\x06\x01\x03\x11\x03\x13\x13\x06\x01\x03\x11\x03\x17\x15\x06\x01\x03\x05\x05\x15\x19\x03F\x01\x0f\x03\x1d\x03\x07\x17F\x01\x11\x03/\x05\x11\x1d\x03F\x01\x13\x031\x03\x1f\x03F\x01\x0f\x03\x0b\x03\x05\x03F\x01\x15\x035\x03!\x07\x06\x01\x03\x0b\x07%\x0b#\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x03)\x07\x06\x01\x03\x05\x07-\r+\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x031\x07\x06\x01\x03\x05\x075\x1b3\x19\x04\x03\x07/'7\x06\x03\x01\x05\x01\x00\x8a\x07G)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x17\x15\x11\x11\x1b\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08M\x19\x05#\x01\x0b157EG\x03I\x03K\x03M\x11UWY[]_ae\x03k\x03'\x05mo\x03+\x03q\x03/",
    xla_call_module_version=9,
    nr_devices=1,
) # End paste

data_2024_10_08["jacobi"]["c128"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvdj_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 3.796399075115019 +0.6837988589198791j ,
        -3.4038861521636834 -0.8791658549931785j ,
        5.587025360536978 +1.3741531770045181j ,
        2.546764576857623 +1.6809085185078099j ],
        [-0.29341875154987235-0.7960296026166453j ,
        -3.978920434463012 -3.064639731239574j ,
        7.797646929978315 +0.8348826539375432j ,
        1.3257327511547885 +4.524484078738454j ],
        [ 2.664074451502439 +3.393033766292215j ,
        -1.9919844260791377 -0.6279428424058938j ,
        4.9406207893724465 -1.3141491766038624j ,
        -0.09537336365814258-1.6177405195744095j ],
        [-2.4688921567972058 +0.9746213770706899j ,
        0.5921515121270726 +4.164182480017167j ,
        -2.2589950508277568 +2.6432862086413222j ,
        -2.556559707542412 +8.972441493886869j ]],

        [[-1.1112549656532051 +1.9704765658574988j ,
        8.989616960518267 -3.6609418049818268j ,
        1.2346549691110358 -0.24414962907971388j,
        -2.9090211908941734 +4.153707428606018j ],
        [-0.5588638014996116 -1.67573324865607j ,
        3.149773407631379 +4.604525381223088j ,
        1.128102476802749 -1.3129091142659617j ,
        -3.9361491309229013 -0.4879709640370399j ],
        [-1.7318669937731295 -2.869679468975394j ,
        -3.0523599142955913 +8.95268082648917j ,
        -2.6723736459369936 -2.1677507057699845j ,
        2.2856025738573873 +4.128295675578359j ],
        [-3.0608685537602445 -0.21903335326027606j,
        -4.833657640993765 -2.184980999873441j ,
        -1.4399110875167116 +0.23459254553652992j,
        1.9463909302306492 -6.990119453626141j ]]]),),
    expected_outputs=(array([[[ 0.5058374100694445 +0.08736973325822138j ,
        -0.15357420292088966 +0.15092672888666417j ,
        0.4000099867967776 +0.03343224253498994j ,
        -0.08606623675966052 -0.7222174392115995j ],
        [ 0.5494982003369341 +0.21627530772570236j ,
        -0.24559859535168804 +0.42796239936478525j ,
        -0.35789689606765984 -0.3818532162746145j ,
        0.1870806116265344 +0.3144916716626133j ],
        [ 0.35367519290177574 -0.09104684890511797j ,
        -0.2722754551548636 -0.01520202737842482j ,
        0.20564615603913944 +0.6788851354546692j ,
        -0.26426814455261194 +0.46823742188539086j ],
        [-0.26578682435990086 +0.42866473683552114j ,
        0.3753385929542211 +0.7035065862655021j ,
        0.11671177248101379 +0.21948854675620316j ,
        -0.21652589234437378 +0.03351132737269058j ]],

        [[ 0.6447398173517832 -0.15549362721502685j ,
        0.15703159957366286 -0.18175797380505443j ,
        -0.26984447689730545 -0.021902292069557842j,
        -0.5809856180118731 +0.30265058246070115j ],
        [ 0.1676295953197978 +0.35517400244986735j ,
        0.1471535132890569 -0.11945200848513432j ,
        -0.048587988684360706-0.5959669737080945j ,
        0.4658299942602748 +0.485070920592558j ],
        [-0.18394077259047054 +0.4607473838424225j ,
        -0.6953163697460342 -0.28046734855011674j ,
        -0.31852936346663746 +0.1408757043223249j ,
        -0.18675394341953266 +0.18859188206317357j ],
        [-0.38736135459868454 -0.0985538837287597j ,
        0.1400965142945759 +0.5697616659026354j ,
        -0.6360840273269505 -0.20798320176325613j ,
        -0.16572122692263513 +0.14373411782489567j ]]]), array([[14.800219047973494 , 10.208626444252278 , 5.121442071722992 ,
        2.3317857898198886],
        [16.387010501961203 , 10.85923071644345 , 4.434400577803048 ,
        0.7913357405499906]]), array([[[ 0.22661733590854113 +0.12716811836891184j ,
        -0.24780255554134717 -0.18478518062614943j ,
        0.7440472412639979 -0.05201706404993294j ,
        0.5257597465394888 +0.0646977585249859j ],
        [-0.1730295430553622 +0.08448133745769008j ,
        0.3682599006560183 +0.4301598551539888j ,
        -0.2470430370452839 -0.1549815154245142j ,
        0.6735915557971368 +0.3217074899922104j ],
        [ 0.9230886222728419 -0.02650327461707979j ,
        0.26368770031726146 +0.17940633187899951j ,
        -0.07583142147304134 +0.04327009313242354j ,
        -0.11210558432341579 +0.15904958085235385j ],
        [ 0.13986036936906365 +0.15179124656114312j ,
        -0.2301030912215956 -0.6550803059657956j ,
        -0.4696876553365799 -0.3611221351860597j ,
        0.15426668090371212 +0.31702832012208193j ]],

        [[-0.09203102957042177 +0.1296290046576001j ,
        0.9338316046889822 -0.07199509757761581j ,
        0.035650554858596174+0.04949401150838972j ,
        -0.11826036178363568 +0.2824813592971463j ],
        [ 0.09583662270821797 +0.27782657909919606j ,
        -0.029480552998465036-0.2320825865452519j ,
        0.27250121923986637 +0.16010847179743545j ,
        -0.7541782211683652 -0.4361420378707925j ],
        [ 0.77179564007941 -0.03313538642869378j ,
        0.11720446778594507 +0.18064177478753635j ,
        0.40879934622962677 +0.326378840223041j ,
        0.2808442732051157 -0.06596695716827156j ],
        [ 0.5393502219452392 +0.0262529741146748j ,
        0.14580725795066302 -0.020389102179641145j,
        -0.6823399052632065 -0.3964340353663506j ,
        -0.12461789414204025 -0.2201348160267379j ]]])),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xcomplex<f64>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f64>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
%cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvdj_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f64>>) -> (tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>>, tensor<2x4x4xcomplex<f64>>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.transpose %0#3, dims = [0, 2, 1] : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%2 = stablehlo.real %1 : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xf64> loc(#loc3)
%3 = stablehlo.imag %1 : (tensor<2x4x4xcomplex<f64>>) -> tensor<2x4x4xf64> loc(#loc3)
%4 = stablehlo.negate %3 : tensor<2x4x4xf64> loc(#loc3)
%5 = stablehlo.complex %2, %4 : tensor<2x4x4xcomplex<f64>> loc(#loc3)
%6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%7 = stablehlo.compare EQ, %0#4, %6, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%10 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%11 = stablehlo.select %10, %0#1, %9 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%14 = stablehlo.broadcast_in_dim %12, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%15 = stablehlo.select %14, %0#2, %13 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
%16 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
%18 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%19 = stablehlo.select %18, %5, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
return %15, %11, %19 : tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":686:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01+\x05\x01\x05\x1b\x01\x03\x0b\x03\x19\x0f\x13\x17\x1b\x1f#'+/37;\x03\xbfs9\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0bO/\x1f\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/o\x0b\x0bO\x01\x05\x0b\x0f\x035\x1b\x07\x07\x17\x07\x07\x1b\x0b\x0f\x0f\x0f\x07\x13\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x02\x9a\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x15\x03\x05'\x03\x03\x19O\x05)\x05+\x17\x1f\xba\n\x1b\x05-\x1d/\x1d1\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x1d3\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x033\r\x03!##%\x03\x079=A\r\x05);!#\x1d5\r\x05)?!#\x1d7\r\x05)C!#\x1d9\x1d;\x1d=\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19\t\x00\x00\x00\x00\r\x05Q-S-\x1d?\x1dA\x0b\x03\x1dC\x1dE\x03\x01\x05\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f1\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x13\x01\x0b)\x05\t\x11\t\x1d\x13)\x07\t\x11\x11\t\x03\t)\x01\x13)\x01\t)\x01\x1b\x1b)\x03\t\x1b)\x03\r\r)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\r)\x03\t\x07)\x05\t\x05\x07)\x03\x05\r)\x05\t\x11\x07)\x03\t\r\x04n\x03\x05\x01Q\x03\x07\x01\x07\x04F\x03\x03\x01\x05\tP\x03\x03\x07\x04\x1a\x03\x039c\x03\x0b\x13\x00\x05B\x03\x05\x03\x15\x05B\x03\x07\x03\x17\x05B\x03\t\x03\x19\x0bG\x01\x17\x0b\x0b\x05\x0b\x05\x05\x1d\x03\x01\rF\x01\r\x03\x05\x03\x0f\x0f\x06\x01\x03\x11\x03\x13\x11\x06\x01\x03\x11\x03\x13\x13\x06\x01\x03\x11\x03\x17\x15\x06\x01\x03\x05\x05\x15\x19\x03F\x01\x0f\x03\x1d\x03\x07\x17F\x01\x11\x03/\x05\x11\x1d\x03F\x01\x13\x031\x03\x1f\x03F\x01\x0f\x03\x0b\x03\x05\x03F\x01\x15\x035\x03!\x07\x06\x01\x03\x0b\x07%\x0b#\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x03)\x07\x06\x01\x03\x05\x07-\r+\x03F\x01\x13\x03!\x03\x1f\x03F\x01\x0f\x03\x05\x03\x03\x03F\x01\x17\x03#\x031\x07\x06\x01\x03\x05\x075\x1b3\x19\x04\x03\x07/'7\x06\x03\x01\x05\x01\x00\x8a\x07G)\x03\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x17\x15\x11\x11\x1b\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00\x00cusolver_gesvdj_ffi\x00\x08M\x19\x05#\x01\x0b157EG\x03I\x03K\x03M\x11UWY[]_ae\x03k\x03'\x05mo\x03+\x03q\x03/",
    xla_call_module_version=9,
    nr_devices=1,
) # End paste

data_2024_10_08["qr"]["f32"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvd_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 7.064613 , 4.4742312 , -0.12700312, -0.71483076],
        [-0.59317935, 4.0224333 , 0.5515773 , 6.009665 ],
        [-5.193879 , -1.0297644 , -4.388829 , 2.3485358 ],
        [-0.8724199 , -1.5610907 , 0.47096923, 0.10478485]],

        [[ 0.7020009 , 0.0506321 , 2.5788887 , -0.44895908],
        [ 2.617715 , -1.5580447 , -0.9952533 , 1.3444504 ],
        [ 0.7077899 , -0.9494638 , 2.3607216 , -4.8069396 ],
        [-0.9731158 , 2.762172 , -3.6125846 , -0.59783787]]],
        dtype=float32),),
    expected_outputs=(array([[[-0.7763474 , -0.1832199 , -0.54734766 , 0.25323072 ],
        [-0.025878029, -0.93325573 , 0.35778683 , 0.018767256],
        [ 0.6168491 , -0.29107273 , -0.72102463 , 0.12205475 ],
        [ 0.12693313 , 0.10363644 , 0.22917844 , 0.959492 ]],

        [[-0.35366213 , -0.12650901 , 0.30348226 , 0.8756809 ],
        [ 0.08954996 , -0.56219155 , -0.7897131 , 0.22863595 ],
        [-0.7685047 , 0.45641905 , -0.4388071 , -0.09236202 ],
        [ 0.5256468 , 0.6779513 , -0.30281967 , 0.41518393 ]]],
        dtype=float32), array([[10.327013 , 7.659154 , 3.736682 , 0.61586076],
        [ 6.3917613 , 4.5939665 , 2.8317585 , 1.2306355 ]],
        dtype=float32), array([[[-0.8505677 , -0.42713356 , -0.24819751 , 0.18024875 ],
        [ 0.088860154, -0.5791473 , 0.108991824, -0.8030028 ],
        [-0.14292236 , -0.16727918 , 0.94716436 , 0.23338944 ],
        [ 0.49820864 , -0.6739162 , -0.17145996 , 0.51790595 ]],

        [[-0.16729502 , 0.31668338 , -0.7375667 , 0.5724681 ],
        [-0.41296393 , 0.5025677 , -0.24780482 , -0.7179688 ],
        [-0.6604038 , 0.29167938 , 0.5744389 , 0.38575882 ],
        [ 0.6044337 , 0.7497071 , 0.25418177 , 0.08939337 ]]],
        dtype=float32)),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x4x4xf32> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19O\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x01\x1d)\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##\x1d\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x11\t\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\x00\x00\r\x07Q-S-U/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\t\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xc2\x02\x05\x01Q\x03\x07\x01\x07\x04\x9a\x02\x03\x01\x05\tP\x03\x03\x07\x04n\x02\x03-K\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\x03F\x01\x0b\x03\x17\x03\x05\rF\x01\r\x03'\x05\x0f\x11\x03F\x01\x0f\x03)\x03\x13\x03F\x01\x0b\x03\t\x03\x03\x03F\x01\x11\x03-\x03\x15\x05\x06\x01\x03\t\x07\x19\t\x17\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03\x1d\x05\x06\x01\x03\x05\x07!\x0b\x1f\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03%\x05\x06\x01\x03\x05\x07)\r'\x0f\x04\x03\x07#\x1b+\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08E\x15\x05#\x01\x0b379GI\x03K\x03M\x11WY[]/_ae\x03'\x05km\x03+\x03o\x031",
    xla_call_module_version=9,
    nr_devices=1,
) # End paste

data_2024_10_08["qr"]["f64"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvd_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 1.6666730531514784 , -0.283751211500066 ,
        3.0028872858127387 , -3.539066961520449 ],
        [-3.5009517179281326 , -3.0927716333012025 ,
        0.7807364288380038 , -1.7085927853808216 ],
        [ 1.6481964138894265 , -1.0457448512148775 ,
        6.119638350643893 , 0.6798789946663015 ],
        [ 0.45137958693199876 , -3.0487560288436093 ,
        -3.5653048640383225 , 3.1078238891060703 ]],

        [[-0.18274271852634477 , -4.107847311422953 ,
        -6.970910660834766 , 0.026925818090434366],
        [ 0.609826504319294 , -0.08188345529211022 ,
        0.4988730098060886 , 0.5371129476436916 ],
        [-1.8478868672212718 , 5.777430685108351 ,
        3.6156805021156426 , -5.4328316768756855 ],
        [ 2.839181461365762 , -5.931652277530413 ,
        -6.898420189145518 , 0.33823029148249784 ]]]),),
    expected_outputs=(array([[[ 0.5185598212776665 , 0.08640745162480469 ,
        -0.22604389271595665 , 0.8200814731634898 ],
        [ 0.058783626848861056, 0.9889125760720038 ,
        0.040834454253227216, -0.13011129638493632 ],
        [ 0.6559139736574813 , -0.09584764795840599 ,
        0.7197537968326447 , -0.20626332559785515 ],
        [-0.5453595659120885 , 0.07347719082261084 ,
        0.6551268410442749 , 0.5176802762713159 ]],

        [[ 0.5266925395219686 , 0.49814307967604127 ,
        0.6468374971059552 , 0.23674816434446275 ],
        [-0.004398828045962558, -0.15543570506880638 ,
        -0.22846771664609714 , 0.9610530132891235 ],
        [-0.5473983592728412 , 0.8074523783002769 ,
        -0.20514413534806356 , 0.07931946025360624 ],
        [ 0.6503311890022828 , 0.27516153535306986 ,
        -0.6980828307017146 , -0.11847293172915115 ]]]), array([[ 8.364583718088932 , 5.052514677781009 , 4.631967246026687 ,
        2.5284502274249308 ],
        [14.318697073130886 , 5.219146075652053 , 2.112106833842507 ,
        0.15087576902549965]]), array([[[ 0.17853631156202626 , 0.0774459691431368 ,
        0.9039781164136 , -0.3807236167650057 ],
        [-0.6814293626119452 , -0.6346900565025612 ,
        0.036225601414469004, -0.3626434361038585 ],
        [ 0.20775315131153385 , -0.6071183144157039 ,
        0.30699755938819207 , 0.7028502535752935 ],
        [ 0.6786880265219302 , -0.4717817359132189 ,
        -0.29540441665143263 , -0.47910415040709536 ]],

        [[ 0.19268560067166468 , -0.641350735371731 ,
        -0.7081088748435267 , 0.22388236844328566 ],
        [-0.1718035739905713 , 0.19146225969939276 ,
        -0.4845129100140278 , -0.8361058396546505 ],
        [-0.8808414232069264 , 0.1501707384172312 ,
        -0.25997252143926985 , 0.3660347313883319 ],
        [ 0.39683016319410813 , 0.7276401491617723 ,
        -0.4429936224079104 , 0.34179275213656773 ]]])),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
%11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
%12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<2x4x4xf64> loc(#loc3)
%13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
%14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3)
return %10, %6, %14 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xb7q3\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03Q\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19O\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x01\x1d)\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##\x1d\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\t\x00\x00\x00\x00\r\x07Q-S-U/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03c\x15\x03\x01\x01\x01\x03\x0b%g%%i\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\r\x01)\x05\t\x11\r\x1d\x0b\x13)\x01\r)\x01\x15\x1b)\x03\t\x15)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xc2\x02\x05\x01Q\x03\x07\x01\x07\x04\x9a\x02\x03\x01\x05\tP\x03\x03\x07\x04n\x02\x03-K\x03\x0b\x13\x00\x07B\x03\x05\x03\x11\x07B\x03\x07\x03\x13\x0bG\x01\x17\t\x0b\x05\t\x05\x05\x17\x03\x01\x03F\x01\x0b\x03\x17\x03\x05\rF\x01\r\x03'\x05\x0f\x11\x03F\x01\x0f\x03)\x03\x13\x03F\x01\x0b\x03\t\x03\x03\x03F\x01\x11\x03-\x03\x15\x05\x06\x01\x03\t\x07\x19\t\x17\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03\x1d\x05\x06\x01\x03\x05\x07!\x0b\x1f\x03F\x01\x0f\x03\x19\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x13\x03\x1b\x03%\x05\x06\x01\x03\x05\x07)\r'\x0f\x04\x03\x07#\x1b+\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x19\x15)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00select_v1\x00constant_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08E\x15\x05#\x01\x0b379GI\x03K\x03M\x11WY[]/_ae\x03'\x05km\x03+\x03o\x031",
    xla_call_module_version=9,
    nr_devices=1,
)  # End paste

data_2024_10_08["qr"]["c64"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvd_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[ 1.0732381 +3.5808065j , 3.9696057 +0.20698753j,
         -0.6425436 +0.5669031j , 2.2608232 +3.3495777j ],
        [-0.14261106+1.9452835j , 0.7328605 -3.497075j ,
         -0.2833068 +2.5005085j , -5.3383408 -0.13752732j],
        [ 0.378204 +0.31973448j, -0.82705253-1.2925721j ,
          4.6363106 -0.6026158j , -4.2700663 +0.3752707j ],
        [-0.5445609 -2.3843932j , -1.5469118 -0.22753051j,
         -1.2669541 -5.0028024j , 0.03133653-3.4463193j ]],

       [[ 2.5795584 +0.6289093j , 1.27082 +3.8879561j ,
          1.9604253 +1.1865004j , -2.399359 +1.3273407j ],
        [-1.4925842 +5.878235j , -5.6121607 -3.4182298j ,
         -1.9998045 +0.10950515j, 1.9120374 +3.3194423j ],
        [ 0.40499273-2.4316337j , 1.6648822 +2.6184802j ,
         -4.0471325 +3.133329j , -0.76832575-1.7445682j ],
        [-1.895143 +0.5848787j , 2.6240315 +5.021989j ,
          1.449456 +3.472618j , -2.0201976 -3.5186274j ]]],
      dtype=complex64),),
    expected_outputs=(array([[[-0.25113735 -0.21140018j , 0.018947408 -0.64372665j ,
         -0.070043676 -0.5433938j , 0.21103051 +0.3643902j ],
        [ 0.34577918 -0.55255175j , -0.21863852 +0.16678011j ,
         -0.48150375 -0.24859619j , -0.45313844 +0.022904329j],
        [-0.07727728 -0.37749314j , -0.107260175 +0.56208193j ,
          0.30505398 -0.3713933j , 0.53557503 -0.07908653j ],
        [ 0.14142953 +0.5467069j , 0.13847762 +0.4037597j ,
         -0.24584742 -0.33873183j , -0.0011076198+0.56897277j ]],

       [[-0.15231489 +0.2637356j , -0.333622 -0.3903981j ,
         -0.3809538 -0.1761738j , -0.40208697 -0.5528941j ],
        [ 0.08597728 -0.6945265j , 0.21902993 -0.445389j ,
         -0.33454537 -0.015148819j, -0.28255132 +0.26816013j ],
        [-0.27464584 +0.24423108j , 0.3478235 +0.3221502j ,
         -0.740009 -0.16378604j , 0.15596838 +0.20345287j ],
        [-0.25253323 +0.4675811j , 0.37721652 -0.35055056j ,
          0.3376375 -0.15247335j , -0.38508245 +0.40851057j ]]],
      dtype=complex64), array([[ 9.668089 , 8.574805 , 4.549492 , 0.5780793],
       [12.876013 , 7.7651014, 4.0119534, 2.2829206]], dtype=float32), array([[[-0.38075778 +0.j , 0.14002012 +0.060418203j ,
         -0.4637056 +0.22876605j , -0.489978 -0.5695017j ],
        [-0.32981405 +0.j , -0.20355047 +0.51292306j ,
         -0.34164208 -0.4227407j , -0.19677992 +0.50254196j ],
        [-0.3292037 +0.j , 0.17828293 +0.62404454j ,
          0.63654786 +0.14848639j , 0.07557414 -0.19353445j ],
        [ 0.7986684 +0.j , 0.056182697 +0.4978434j ,
         -0.099771045 -0.0043060416j, -0.28370282 -0.14375056j ]],

       [[-0.3410219 +0.j , 0.35656637 -0.6787797j ,
          0.22528769 -0.27213925j , -0.21556833 +0.35289994j ],
        [-0.72291875 +0.j , -0.12834571 -0.11083051j ,
         -0.3442186 +0.47835365j , -0.14619395 -0.28275603j ],
        [-0.3274419 +0.j , -0.19452412 +0.057822693j ,
          0.53667945 -0.43909317j , 0.17421937 -0.5834539j ],
        [ 0.50385934 -0.j , -0.06922947 -0.58084977j ,
          0.0073776054+0.21678242j , -0.24243501 -0.546006j ]]],
      dtype=complex64)),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x4x4xcomplex<f32>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f32>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f32>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
    %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f32>>) -> (tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>>, tensor<2x4x4xcomplex<f32>>, tensor<2xi32>) loc(#loc3)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
    %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
    %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
    %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<2x4xf32> loc(#loc3)
    %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
    %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3)
    %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
    %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
    %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
    %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
    %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
    %12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<2x4x4xcomplex<f32>> loc(#loc3)
    %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
    %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f32>> loc(#loc3)
    return %10, %6, %14 : tensor<2x4x4xcomplex<f32>>, tensor<2x4xf32>, tensor<2x4x4xcomplex<f32>> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbds7\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b/\x1f\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x033\x1b\x07\x17\x07\x07\x07\x0b\x0f\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xf6\x05\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19Q\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1d)\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##!\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x15\t\x00\x00\xc0\x7f\x1f\x17\t\x00\x00\x00\x00\r\x07S-U-W/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03e\x15\x03\x01\x01\x01\x03\x0b%i%%k\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11\x01)\x05\t\x11\r\x1d\t\x13\x03\r)\x01\x11)\x01\r)\x01\x19\x1b)\x03\t\x19)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xda\x02\x05\x01Q\x03\x07\x01\x07\x04\xb2\x02\x03\x01\x05\tP\x03\x03\x07\x04\x86\x02\x03/O\x03\x0b\x13\x00\x05B\x03\x05\x03\x13\x05B\x03\x07\x03\x15\x05B\x03\t\x03\x17\x0bG\x01\x17\x0b\x0b\x05\t\x05\x05\x1b\x03\x01\x03F\x01\r\x03\x1b\x03\x07\rF\x01\x0f\x03+\x05\x11\x13\x03F\x01\x11\x03-\x03\x15\x03F\x01\r\x03\t\x03\x05\x03F\x01\x13\x031\x03\x17\x07\x06\x01\x03\t\x07\x1b\x0b\x19\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03\x1f\x07\x06\x01\x03\x05\x07#\r!\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03'\x07\x06\x01\x03\x05\x07+\x0f)\x0f\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08I\x17\x05#\x01\x0b379GI\x03K\x03M\x03O\x11Y[]_/acg\x03'\x05mo\x03+\x03q\x031",
    xla_call_module_version=9,
    nr_devices=1,
)  # End paste

data_2024_10_08["qr"]["c128"] = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['cusolver_gesvd_ffi'],
    serialized_date=datetime.date(2024, 10, 8),
    inputs=(array([[[-4.699732537674587 -0.18164091805215468j,
         -2.6529672987457267 +2.1873545441571416j ,
         -4.583623180305504 +0.05141217967419283j,
         -1.4684446379730842 +0.5956695859134695j ],
        [ 2.217429580673316 -1.6820541069935535j ,
         -1.489637886109648 -1.1907523648513954j ,
         -5.37070728884717 +0.3011497067658051j ,
         -3.5377553884933244 +1.560799473477663j ],
        [ 0.4865985561509131 +4.547548126143047j ,
          1.9744285723487844 +1.579347193702052j ,
          3.662108610237921 -3.8947365367486944j ,
         -0.46900368026456773+3.897268760016375j ],
        [-3.9057171822032837 +0.894017787659835j ,
         -2.665956542656175 -5.446062606216615j ,
          6.586068520522582 +7.82920032979931j ,
          0.2438426632437082 -2.5324000439269967j ]],

       [[ 2.3593407739528036 +0.1518669531658939j ,
          0.6163481796609258 +2.2151855304705617j ,
          1.1710769743888314 -6.27345033430341j ,
          0.9738490103384626 +0.5395897278168652j ],
        [-2.4788654273898656 +0.4265527313512031j ,
         -1.1807578044484868 -0.0496832499163036j ,
         -4.4976038167764765 +1.058853052811918j ,
         -1.1727797045618331 -5.283007446632174j ],
        [ 2.1607883932036422 +0.15328185326939148j,
          0.33959787374719413-0.44019437888510504j,
          5.554548585416958 -5.5054723821239575j ,
          3.6501512907075853 +2.5205805340930167j ],
        [ 1.3385284474824868 -5.630140770855095j ,
         -0.27414990799969 -0.46452124262376304j,
          1.611578799750626 +8.022764935423794j ,
         -2.616414597337455 +0.02175053549931295j]]]),),
    expected_outputs=(array([[[ 0.010550471640845436-0.1718885244537855j ,
          0.7192461888739357 +0.05423301579908536j ,
         -0.3986393309169209 +0.0992550693722462j ,
          0.2704504005919772 -0.45626573218528016j ],
        [ 0.1757158215796512 -0.288981687849456j ,
          0.2705184461915344 +0.2220238094575261j ,
          0.8080741699984124 +0.3193834688705089j ,
          0.08837920732925486 -0.018389772805120566j],
        [ 0.18559809104086944 +0.39876558214096686j ,
         -0.035649910397314105-0.49144324300371245j ,
          0.15413388674267017 +0.0406886784503357j ,
          0.7321684611544159 +0.047628802113382905j],
        [-0.8024624385313128 +0.13619820370204241j ,
         -0.23383020362294066 +0.2445505198523447j ,
          0.12332909849753147 +0.18873939840686926j ,
          0.26625208768341435 -0.3182762352506706j ]],

       [[-0.19373811197499574 -0.3678054849282306j ,
          0.2636906578352804 +0.07973830233400031j ,
         -0.22842504880430503 +0.7133040647598473j ,
         -0.3539448620401726 +0.25502167015286226j ],
        [-0.24171920617553377 +0.3311681855178457j ,
         -0.516368216465723 +0.1379819660618946j ,
         -0.45968933418302504 +0.04091891748552965j ,
          0.20304338982129244 +0.5403786083964173j ],
        [ 0.02755060979123564 -0.5776868042861196j ,
          0.27416989428331057 -0.1571567241856725j ,
         -0.2065632715230957 -0.2247313731964578j ,
          0.630967059955887 +0.27268947023173257j ],
        [ 0.4031484266707497 +0.40258464155812185j ,
          0.22887811112799977 -0.6972670403334557j ,
         -0.29285414368652807 +0.2170127687269824j ,
          0.03733951130463506 +0.050775060769343766j]]]), array([[15.105031148122244 , 9.10491991264034 , 5.006211740104105 ,
         3.446376589720919 ],
       [15.343823952995173 , 7.3753715646873195, 3.7496815109995807,
         0.8625145657311305]]), array([[[ 0.398346102099369 +0.j ,
          0.1371865428555439 +0.20963212142757817j ,
         -0.40914166828555465 -0.7712183712683253j ,
         -0.01748398034296661 +0.12678422058548106j ],
        [-0.4705168176319191 +0.j ,
         -0.44062591598236445 +0.5013959991714195j ,
         -0.27697957785810795 +0.006226163227350533j,
         -0.46230452584693416 +0.2063561296099042j ],
        [ 0.6106772453975169 +0.j ,
         -0.2591680391544425 -0.21982460883707244j ,
          0.0568261232179787 +0.2729248746486013j ,
         -0.41495999423821034 +0.511540202216849j ],
        [-0.49699860082446057 +0.j ,
          0.2086557264813747 -0.5767641952667593j ,
          0.0041166738594973 -0.2886775449288538j ,
         -0.0862160321999865 +0.5348021741590255j ]],

       [[-0.09961676341576099 +0.j ,
         -0.045561706463074156+0.0200549951728976j ,
          0.6993928144861429 +0.5554242311281644j ,
         -0.27729760005165593 +0.3362411099877399j ],
        [ 0.9183964339371264 +0.j ,
          0.18513592789573766 +0.048643367601494444j,
         -0.07592192157413628 +0.08808182166509272j ,
          0.022653328707496152+0.3253779069593685j ],
        [-0.36489390866852983 +0.j ,
          0.5302626885445338 -0.13646920230359977j ,
         -0.339383962926636 -0.005000573360891423j,
         -0.01709833277905625 +0.6719756248311359j ],
        [ 0.11609016321265409 +0.j ,
          0.16300036810869795 -0.7965607051693728j ,
          0.13402108640742152 -0.23592994174105483j ,
         -0.4709038361949722 -0.17340699243933513j ]]])),
    mlir_module_text=r"""
#loc1 = loc("operand")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x4x4xcomplex<f64>> {mhlo.layout_mode = "default"} loc("operand")) -> (tensor<2x4x4xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex<f64>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
    %cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%arg0) {mhlo.backend_config = {compute_uv = true, full_matrices = true, transposed = false}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex<f64>>) -> (tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>>, tensor<2x4x4xcomplex<f64>>, tensor<2xi32>) loc(#loc3)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2xi32> loc(#loc3)
    %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3)
    %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3)
    %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f64>) -> tensor<2x4xf64> loc(#loc3)
    %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3)
    %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3)
    %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
    %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
    %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
    %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
    %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3)
    %12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<2x4x4xcomplex<f64>> loc(#loc3)
    %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3)
    %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex<f64>> loc(#loc3)
    return %10, %6, %14 : tensor<2x4x4xcomplex<f64>>, tensor<2x4xf64>, tensor<2x4x4xcomplex<f64>> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":690:13)
#loc3 = loc("jit(func)/jit(main)/svd"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.0\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xbds7\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03S\x0b\x0bo\x0f\x0b/\x0b\x0bo\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0bO/\x1f#\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1fO/\x0b\x0bO\x01\x05\x0b\x0f\x033\x1b\x07\x17\x07\x07\x07\x0b\x0f\x0f\x0f\x07\x13\x1b\x1b\x1f\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02&\x06\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19Q\x05\x1f\x05!\x17\x1f\xca\n\x1b\x05#\x1d%\x1d'\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1d)\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x05\x03\x05\x01\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x035\r\x03!##!\x03\x07;?C\r\x05)=!#\x1d+\r\x05)A!#\x1d-\r\x05)E!#\x1d/\x1d1\x1d3\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17\t\x00\x00\x00\x00\r\x07S-U-W/\x1d5\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x03\x03%\x03\x03e\x15\x03\x01\x01\x01\x03\x0b%i%%k\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11\x01)\x05\t\x11\r\x1d\x0b\x13\x03\r)\x01\x11)\x01\r)\x01\x19\x1b)\x03\t\x19)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\t\x05)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0b)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0b)\x05\t\x11\x07)\x03\t\x0b)\x03\r\x0b\x04\xda\x02\x05\x01Q\x03\x07\x01\x07\x04\xb2\x02\x03\x01\x05\tP\x03\x03\x07\x04\x86\x02\x03/O\x03\x0b\x13\x00\x05B\x03\x05\x03\x13\x05B\x03\x07\x03\x15\x05B\x03\t\x03\x17\x0bG\x01\x17\x0b\x0b\x05\t\x05\x05\x1b\x03\x01\x03F\x01\r\x03\x1b\x03\x07\rF\x01\x0f\x03+\x05\x11\x13\x03F\x01\x11\x03-\x03\x15\x03F\x01\r\x03\t\x03\x05\x03F\x01\x13\x031\x03\x17\x07\x06\x01\x03\t\x07\x1b\x0b\x19\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03\x1f\x07\x06\x01\x03\x05\x07#\r!\x03F\x01\x11\x03\x1d\x03\x15\x03F\x01\r\x03\x05\x03\x03\x03F\x01\x15\x03\x1f\x03'\x07\x06\x01\x03\x05\x07+\x0f)\x0f\x04\x03\x07%\x1d-\x06\x03\x01\x05\x01\x00\xda\x06?'\x03\x17\x1d\x17\x0f\x0b\t\t\t!\x11#i1)\x11\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00operand\x00mhlo.backend_config\x00jit(func)/jit(main)/svd\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00compute_uv\x00full_matrices\x00transposed\x00\x00cusolver_gesvd_ffi\x00\x08I\x17\x05#\x01\x0b379GI\x03K\x03M\x03O\x11Y[]_/acg\x03'\x05mo\x03+\x03q\x031",
    xla_call_module_version=9,
    nr_devices=1,
)  # End paste
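Each entry above freezes the inputs, expected outputs, StableHLO text, and serialized module so the compatibility test can replay the old call against today's lowering. Roughly how the replay looks, as a sketch using the CompatTest methods that appear in the test diff further down (not a standalone script):

    # Inside a CompatTest method; see export_back_compat_test.py below.
    info = cuda_svd_cusolver_gesvd.data_2024_10_08["qr"]["c64"]
    data = self.load_testdata(info)
    def func(operand):
      return lax.linalg.svd(operand, full_matrices=True, compute_uv=True,
                            algorithm=lax.linalg.SvdAlgorithm.QR)
    self.run_one_test(func, data, rtol=1e-3, atol=1e-4,
                      check_results=partial(self.check_svd_results,
                                            *data.inputs))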
@@ -15,6 +15,7 @@
 from __future__ import annotations

 from collections.abc import Callable
+import enum
 import functools
 from functools import partial
 import math
@@ -308,6 +309,13 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
   return q, r


+class SvdAlgorithm(enum.Enum):
+  """Enum for SVD algorithm."""
+  DEFAULT = "default"
+  QR = "QR"
+  JACOBI = "Jacobi"
+
+
 @overload
 def svd(
     x: ArrayLike,
@@ -315,6 +323,7 @@ def svd(
     full_matrices: bool = True,
     compute_uv: Literal[True],
     subset_by_index: tuple[int, int] | None = None,
+    algorithm: SvdAlgorithm | None = None,
 ) -> tuple[Array, Array, Array]:
   ...

@@ -326,6 +335,7 @@ def svd(
     full_matrices: bool = True,
     compute_uv: Literal[False],
     subset_by_index: tuple[int, int] | None = None,
+    algorithm: SvdAlgorithm | None = None,
 ) -> Array:
   ...

@@ -337,6 +347,7 @@ def svd(
     full_matrices: bool = True,
     compute_uv: bool = True,
     subset_by_index: tuple[int, int] | None = None,
+    algorithm: SvdAlgorithm | None = None,
 ) -> Array | tuple[Array, Array, Array]:
   ...

@@ -348,6 +359,7 @@ def svd(
     full_matrices: bool = True,
     compute_uv: bool = True,
     subset_by_index: tuple[int, int] | None = None,
+    algorithm: SvdAlgorithm | None = None,
 ) -> Array | tuple[Array, Array, Array]:
   """Singular value decomposition.

@@ -360,6 +372,7 @@ def svd(
       full_matrices=full_matrices,
       compute_uv=compute_uv,
       subset_by_index=subset_by_index,
+      algorithm=algorithm,
   )
   if compute_uv:
     s, u, v = result
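With the `algorithm` parameter threaded through `svd` as above, callers can bypass the size-based lowering heuristics. A minimal usage sketch (assumes a jaxlib built with the new FFI kernels; `svd` and `SvdAlgorithm` are the names added by this change):

    import jax.numpy as jnp
    from jax.lax.linalg import svd, SvdAlgorithm

    a = jnp.ones((2, 4, 4), jnp.float32)
    # Force the QR-based kernel, for behavior that stays consistent under
    # shape polymorphism.
    u, s, vt = svd(a, full_matrices=True, compute_uv=True,
                   algorithm=SvdAlgorithm.QR)
    # Or force the Jacobi kernel, which the heuristics would pick anyway
    # for matrices this small on CUDA.
    s_only = svd(a, compute_uv=False, algorithm=SvdAlgorithm.JACOBI)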
@@ -1062,15 +1075,12 @@ batching.primitive_batchers[eigh_p] = _eigh_batching_rule
 mlir.register_lowering(
     eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'),
     platform='cpu')
-
-if gpu_solver is not None:
-  mlir.register_lowering(
-      eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
-      platform='cuda')
-  mlir.register_lowering(
-      eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
-      platform='rocm')
-
+mlir.register_lowering(
+    eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
+    platform='cuda')
+mlir.register_lowering(
+    eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
+    platform='rocm')
 mlir.register_lowering(
     eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
     platform='tpu')
@@ -1890,22 +1900,26 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))


 # Singular value decomposition
-def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None):
+def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None,
+              algorithm=None):
   return dispatch.apply_primitive(
       svd_p,
       operand,
       full_matrices=full_matrices,
       compute_uv=compute_uv,
       subset_by_index=subset_by_index,
+      algorithm=algorithm,
   )


-def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
+def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
+                       algorithm=None):
+  del algorithm  # unused
   if isinstance(operand, ShapedArray):
     batch_dims = operand.shape[:-2]
     m = operand.shape[-2]
     n = operand.shape[-1]
-    rank = min(m, n)
+    rank = core.min_dim(m, n)
     if subset_by_index is not None:
       if full_matrices and subset_by_index != (0, rank):
         raise ValueError("full_matrices and subset_by_index cannot both be set")
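The `min` to `core.min_dim` switch is what lets the abstract eval accept symbolic dimensions. A small sketch of the difference (`jax._src.core` is the internal module this file already uses; illustrative only):

    from jax import export
    from jax._src import core  # internal, as imported by lax/linalg.py

    m, n = export.symbolic_shape("m, n")
    rank = core.min_dim(m, n)  # stays symbolic, e.g. "min(m, n)"
    # Builtin min(m, n) would need to decide m <= n and raises an
    # InconclusiveDimensionOperation for unconstrained symbols.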
@@ -1927,12 +1941,14 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):

 @config.default_matmul_precision("float32")
 def _svd_jvp_rule(
-    primals, tangents, *, full_matrices, compute_uv, subset_by_index
+    primals, tangents, *, full_matrices, compute_uv, subset_by_index,
+    algorithm=None,
 ):
   A, = primals
   dA, = tangents
   s, U, Vt = svd_p.bind(
-      A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index
+      A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index,
+      algorithm=algorithm,
   )

   if compute_uv and full_matrices:
@@ -1993,24 +2009,18 @@ def _empty_svd(a, *, full_matrices, compute_uv):


 def _svd_cpu_gpu_lowering(
-    gesvd_impl,
     ctx,
     operand,
     *,
     full_matrices,
     compute_uv,
     subset_by_index,
-    platform: str,
+    target_name_prefix: str,
+    algorithm=None,
 ):
   operand_aval, = ctx.avals_in
   s_aval = ctx.avals_out[0]
   m, n = operand_aval.shape[-2:]
-  # Since the last two dimensions (m, n) are used to compute the workspace
-  # size, we support dynamic dimensions only for the batch size for now.
-  if not is_constant_shape([m, n]):
-    raise NotImplementedError(
-      "Shape polymorphism for native serialization for svd on CPU and GPU is "
-      f"implemented only for the batch dimensions: {operand_aval.shape}")
   batch_dims = operand_aval.shape[:-2]

   if not (subset_by_index is None or subset_by_index == (0, min(m, n))):
@@ -2023,22 +2033,43 @@ def _svd_cpu_gpu_lowering(
         full_matrices=full_matrices,
         compute_uv=compute_uv,
     )

-  if platform in ["cuda", "rocm"]:
-    if not is_constant_shape(operand_aval.shape):
-      # TODO(necula): remove the platform kwarg when we implement GPU support.
-      raise NotImplementedError(
-          "Shape polymorphism for native serialization for SVD is not "
-          f"implemented, try to upgrade jaxlib; b/261671778; {operand_aval.shape}")
-    s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
-                                full_matrices=full_matrices,
-                                compute_uv=compute_uv)
+  if target_name_prefix == "cpu":
+    if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
+      raise NotImplementedError(
+          "The SVD algorithm parameter is not implemented on CPU.")
+    target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
+    nb = len(batch_dims)
+    layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
+    result_layouts = [layout, tuple(range(nb, -1, -1)), layout, layout,
+                      tuple(range(nb - 1, -1, -1))]
+    mode = lapack._svd_computation_attr(compute_uv=compute_uv,
+                                        full_matrices=full_matrices)
+    rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
+                            result_layouts=result_layouts,
+                            operand_output_aliases={0: 0})
+    info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
+    if compute_uv:
+      s_aval, u_aval, vt_aval = ctx.avals_out
+    else:
+      s_aval, = ctx.avals_out
+      # TODO(danfm): It should be possible to skip instantiating these arrays
+      # when they are not used.
+      u_aval = ShapedArray((*batch_dims, m,
+                            m if full_matrices else core.min_dim(m, n)),
+                           operand_aval.dtype)
+      vt_aval = ShapedArray((*batch_dims,
+                             n if full_matrices else core.min_dim(m, n), n),
+                            operand_aval.dtype)
+    sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, vt_aval,
+                                     info_aval])
+    _, s, u, vt, info = rule(sub_ctx, operand, mode=mode)
   else:
-    a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-    s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand,
-                                full_matrices=full_matrices,
-                                compute_uv=compute_uv,
-                                a_shape_vals=a_shape_vals)
+    s, u, vt, info = _svd_gpu_sub_lowering(ctx, operand,
+                                           full_matrices=full_matrices,
+                                           compute_uv=compute_uv,
+                                           target_name_prefix=target_name_prefix,
+                                           algorithm=algorithm)

   zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
   ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
   select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
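The layout tuples above are minor-to-major orders: `(nb, nb + 1)` puts the two matrix dimensions innermost (column-major matrices) behind row-major batch dimensions. A quick plain-Python check of the arithmetic (illustrative only; the helper name is made up):

    def matrix_layout(nb):
      # Minor-to-major: columns fastest, then rows, then batch dims.
      return (nb, nb + 1) + tuple(range(nb - 1, -1, -1))

    assert matrix_layout(0) == (0, 1)        # a single matrix
    assert matrix_layout(1) == (1, 2, 0)     # one batch dim, as in the
                                             # dense<[1, 2, 0]> layouts above
    assert matrix_layout(2) == (2, 3, 1, 0)  # two batch dims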
@@ -2071,9 +2102,115 @@ def _svd_cpu_gpu_lowering(
   return result


-def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
-  batch_dims = a.shape[:-2]
+def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv,
+                          target_name_prefix, algorithm):
+  operand_aval, = ctx.avals_in
+  if compute_uv:
+    s_aval, u_aval, vt_aval = ctx.avals_out
+  else:
+    s_aval, = ctx.avals_out
+    u_aval = vt_aval = ShapedArray((), operand_aval.dtype)
+  batch_dims = operand_aval.shape[:-2]
+  info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
+  nb = len(batch_dims)
+  m, n = operand_aval.shape[-2:]
+  k = core.min_dim(m, n)
+
+  transposed = False
+  kwargs = {}
+
+  # The Jacobi algorithm appears to outperform the default QR algorithm for
+  # small to medium sized matrices. See:
+  # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
+  # slide 5. With this in mind, we default to using the Jacobi algorithm for
+  # matrices smaller than 1024x1024.
+  #
+  # Note that the Jacobi algorithm is only used by default for matrices with
+  # concrete matrix dimensions. When using dynamic shapes, we always use the
+  # default QR algorithm, but users can (in principle) override this behavior
+  # by passing algorithm=SvdAlgorithm.JACOBI.
+  #
+  # TODO(danfm): Since this was originally implemented, hipSolver appears to
+  # have added support for the Jacobi algorithm, so we should investigate
+  # removing this condition.
+  if algorithm is None or algorithm == SvdAlgorithm.DEFAULT:
+    try:
+      use_jacobi = target_name_prefix == "cu" and m <= 1024 and n <= 1024
+    except core.InconclusiveDimensionOperation:
+      use_jacobi = False
+  else:
+    use_jacobi = algorithm == SvdAlgorithm.JACOBI
+  if use_jacobi:
+    target_name = f"{target_name_prefix}solver_gesvdj_ffi"
+    # The gesvdjbatched kernel doesn't support "econ" mode, but it also only
+    # supports matrices up to 32x32, so it's always worth using the batched
+    # version and then slicing afterwards when the matrix is small enough.
+    try:
+      econ = not full_matrices and m > 32 and n > 32
+    except core.InconclusiveDimensionOperation:
+      econ = False
+    layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
+  else:
+    target_name = f"{target_name_prefix}solver_gesvd_ffi"
+    econ = not full_matrices
+    # Because the base gesvd kernel only supports matrices where m >= n, we
+    # transpose the input when m < n.
+    transposed = m < n
+    kwargs = {"transposed": transposed}
+    if transposed:
+      layout = tuple(range(nb + 1, -1, -1))
+    else:
+      layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
+
+  result_layouts = [layout, tuple(range(nb, -1, -1)),
+                    layout if use_jacobi or compute_uv else (),
+                    layout if use_jacobi or compute_uv else (),
+                    tuple(range(nb - 1, -1, -1))]
+  rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
+                          result_layouts=result_layouts,
+                          operand_output_aliases={0: 0})
+  if use_jacobi:
+    # When using the Jacobi algorithm, the U and V matrices must always be
+    # allocated even if compute_uv is False.
+    u_aval = ShapedArray((*batch_dims, m, k if econ else m), u_aval.dtype)
+    v_aval = ShapedArray((*batch_dims, n, k if econ else n), vt_aval.dtype)
+    sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, v_aval,
+                                     info_aval])
+  elif transposed:
+    sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, vt_aval, u_aval,
+                                     info_aval])
+  else:
+    sub_ctx = ctx.replace(avals_out=[operand_aval, s_aval, u_aval, vt_aval,
+                                     info_aval])
+  _, s, u, vt, info = rule(sub_ctx, operand, full_matrices=not econ,
+                           compute_uv=compute_uv, **kwargs)
+  if use_jacobi and compute_uv:
+    vt = hlo.transpose(
+        vt,
+        mlir.dense_int_array(np.array(tuple(range(nb)) + (nb + 1, nb))))
+    if np.issubdtype(operand_aval.dtype, np.complexfloating):
+      vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
+  if not full_matrices and not econ:
+    nd = len(operand_aval.shape)
+    u = mlir.slice_op(ctx, u, ctx.avals_out[1],
+                      start_indices=np.zeros([nd], np.int64),
+                      limit_indices=batch_dims + (m, k),
+                      strides=np.ones([nd], np.int64))
+    vt = mlir.slice_op(ctx, vt, ctx.avals_out[2],
+                       start_indices=np.zeros([nd], np.int64),
+                       limit_indices=batch_dims + (k, n),
+                       strides=np.ones([nd], np.int64))
+  if transposed:
+    return s, vt, u, info
+  else:
+    return s, u, vt, info
+
+
+def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None):
+  if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
+    raise NotImplementedError(
+        "The SVD algorithm parameter is not implemented on TPU.")
+
+  batch_dims = a.shape[:-2]
   fn = partial(
       lax_svd.svd,
       full_matrices=full_matrices,
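The transpose-plus-conjugate step exists because the Jacobi kernel returns V while the QR-based kernel (and the rest of the lowering) works with V^H. The same fix-up in numpy terms, as a self-contained sketch (numpy stands in for the cuSOLVER kernels here):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
    u, s, vh = np.linalg.svd(a)  # LAPACK/gesvd convention: returns V^H
    v = vh.conj().T              # what a gesvdj-style kernel hands back
    # The hlo.transpose + element-wise conjugate above recovers V^H from V:
    assert np.allclose(v.conj().T, vh)
    assert np.allclose(u @ np.diag(s) @ vh, a)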
@@ -2092,8 +2229,9 @@ def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):


 def _svd_tpu_lowering_rule(
-    ctx, operand, *, full_matrices, compute_uv, subset_by_index
+    ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None
 ):
+  del algorithm  # unused
   operand_aval, = ctx.avals_in
   m, n = operand_aval.shape[-2:]

@@ -2115,7 +2253,8 @@ def _svd_tpu_lowering_rule(


 def _svd_batching_rule(
-    batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index
+    batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index,
+    algorithm=None,
 ):
   x, = batched_args
   bd, = batch_dims
@@ -2125,6 +2264,7 @@ def _svd_batching_rule(
       full_matrices=full_matrices,
       compute_uv=compute_uv,
       subset_by_index=subset_by_index,
+      algorithm=algorithm,
   )

   if compute_uv:
@@ -2141,18 +2281,14 @@ ad.primitive_jvps[svd_p] = _svd_jvp_rule
 batching.primitive_batchers[svd_p] = _svd_batching_rule

 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo,
-                   platform='cpu'),
+    svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='cpu'),
     platform='cpu')
 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd,
-                   platform='cuda'),
-    platform='cuda')
+    svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='cu'),
+    platform='cuda')
 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd,
-                   platform='rocm'),
-    platform='rocm')
-
+    svd_p, partial(_svd_cpu_gpu_lowering, target_name_prefix='hip'),
+    platform='rocm')
 mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
@@ -3314,6 +3314,7 @@ def _svd(
     full_matrices: bool,
     compute_uv: bool,
     subset_by_index: tuple[int, int] | None = None,
+    algorithm: lax.linalg.SvdAlgorithm | None = None,
 ):
   if not (
       subset_by_index is None
@@ -3321,6 +3322,9 @@ def _svd(
   ):
     raise NotImplementedError("subset_by_index is not implemented")

+  if algorithm is not None and algorithm != lax.linalg.SvdAlgorithm.DEFAULT:
+    raise NotImplementedError("SVD algorithm is not implemented")
+
   result = tf.linalg.svd(operand, full_matrices, compute_uv)
   if not compute_uv:
     return result,
@@ -30,6 +30,7 @@ from jax._src.lax.linalg import (
   qr_p as qr_p,
   svd as svd,
   svd_p as svd_p,
+  SvdAlgorithm as SvdAlgorithm,
   triangular_solve as triangular_solve,
   triangular_solve_p as triangular_solve_p,
   tridiagonal as tridiagonal,
@@ -14,7 +14,6 @@

 from functools import partial
 import importlib
-import math

 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.stablehlo as hlo
@@ -119,138 +118,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
 cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)


-def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
-               full_matrices=True, compute_uv=True):
-  """Singular value decomposition."""
-  a_type = ir.RankedTensorType(a.type)
-  dims = a_type.shape
-  assert len(dims) >= 2
-  m, n = dims[-2:]
-  batch_dims = tuple(dims[:-2])
-  num_bd = len(batch_dims)
-  b = math.prod(batch_dims)
-  if ir.ComplexType.isinstance(a_type.element_type):
-    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
-  else:
-    singular_vals_type = a_type.element_type
-
-  scalar_layout = tuple(range(num_bd - 1, -1, -1))
-  vector_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
-  i32_type = ir.IntegerType.get_signless(32)
-
-  # NVIDIA's batched Jacobi solver supports a maximum matrix size of 32x32, but
-  # the unbatched solver has no such limit. The unbatched solver appears to
-  # outperform gesvd for small-moderate matrices, e.g., see:
-  # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
-  # slide 5.
-  if have_jacobi_solver and m <= 1024 and n <= 1024:
-    # The gesvdjbatched kernel doesn't support "econ" mode. We will use that
-    # kernel only if b > 1 and m <= 32 and n <= 32.
-    econ = not full_matrices and (b <= 1 or m > 32 or n > 32)
-    lwork, opaque = gpu_solver.build_gesvdj_descriptor(
-        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
-    k = min(m, n)
-    matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
-    _, s, u, v, info, _ = custom_call(
-        f"{platform}solver_gesvdj",
-        result_types=[
-          a.type,
-          ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
-          ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
-                                  a_type.element_type),
-          ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
-                                  a_type.element_type),
-          ir.RankedTensorType.get(batch_dims, i32_type),
-          ir.RankedTensorType.get([lwork], a_type.element_type),
-        ],
-        operands=[a],
-        backend_config=opaque,
-        operand_layouts=[matrix_layout],
-        result_layouts=[
-          matrix_layout,
-          vector_layout,
-          matrix_layout,
-          matrix_layout,
-          scalar_layout,
-          [0],
-        ],
-        operand_output_aliases={0: 0}).results
-    vt = hlo.transpose(
-        v,
-        dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
-    if np.issubdtype(dtype, np.complexfloating):
-      vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
-    if not full_matrices and not econ:
-      u = hlo.slice(
-          u,
-          dense_int_array(np.zeros([len(dims)], np.int64)),
-          dense_int_array(np.array(batch_dims + (m, min(m, n)))),
-          dense_int_array(np.ones([len(dims)], np.int64)))
-      vt = hlo.slice(
-          vt,
-          dense_int_array(np.zeros([len(dims)], np.int64)),
-          dense_int_array(np.array(batch_dims + (min(m, n), n))),
-          dense_int_array(np.ones([len(dims)], np.int64)))
-  elif m < n:
-    lwork, opaque = gpu_solver.build_gesvd_descriptor(
-        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
-    k = n if full_matrices else m
-    matrix_layout = (num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))
-    _, s, vt, u, info, _ = custom_call(
-        f"{platform}solver_gesvd",
-        result_types=[
-          a.type,
-          ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
-          ir.RankedTensorType.get(batch_dims + (k, n), a_type.element_type),
-          ir.RankedTensorType.get(batch_dims + (m, m), a_type.element_type),
-          ir.RankedTensorType.get(batch_dims, i32_type),
-          ir.RankedTensorType.get([lwork], a_type.element_type),
-        ],
-        operands=[a],
-        backend_config=opaque,
-        operand_layouts=[matrix_layout],
-        result_layouts=[
-          matrix_layout,
-          vector_layout,
-          matrix_layout,
-          matrix_layout,
-          scalar_layout,
-          [0],
-        ],
-        operand_output_aliases={0: 0}).results
-  else:
-    lwork, opaque = gpu_solver.build_gesvd_descriptor(
-        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
-    k = m if full_matrices else n
-    matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
-    _, s, u, vt, info, _ = custom_call(
-        f"{platform}solver_gesvd",
-        result_types=[
-          a.type,
-          ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
-          ir.RankedTensorType.get(batch_dims + (m, k), a_type.element_type),
-          ir.RankedTensorType.get(batch_dims + (n, n), a_type.element_type),
-          ir.RankedTensorType.get(batch_dims, i32_type),
-          ir.RankedTensorType.get([lwork], a_type.element_type),
-        ],
-        operands=[a],
-        backend_config=opaque,
-        operand_layouts=[matrix_layout],
-        result_layouts=[
-          matrix_layout,
-          vector_layout,
-          matrix_layout,
-          matrix_layout,
-          scalar_layout,
-          [0],
-        ],
-        operand_output_aliases={0: 0}).results
-  return s, u, vt, info
-
-cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True)
-rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False)
-
-
 def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
   """sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
   a_type = ir.RankedTensorType(a.type)

jaxlib/lapack.py
@@ -213,133 +213,6 @@ def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False,
   return out[:2]


-# # ?gesdd: Singular value decomposition
-
-def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
-              a_shape_vals: tuple[DimensionSize, ...]):
-  a_type = ir.RankedTensorType(a.type)
-  assert len(a_shape_vals) >= 2
-  m, n = a_shape_vals[-2:]
-  assert type(m) is int
-  assert type(n) is int
-  batch_dims_vals = a_shape_vals[:-2]
-  num_bd = len(batch_dims_vals)
-  fn_base = prepare_lapack_call(fn_base="gesdd", dtype=dtype)
-  i32_type = ir.IntegerType.get_signless(32)
-  workspace: list[ShapeTypePair]
-
-  # TODO(b/344892332): Remove the old kernel after the compatibility period.
-  if ctx.is_forward_compat():
-    fn = fn_base
-    batch_size_val = hlo_s32(1)
-    for b_v in batch_dims_vals:
-      batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
-    if dtype == np.float32:
-      singular_vals_type = ir.F32Type.get()
-      lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
-      workspace = [
-          ([_lapack.gesdd_iwork_size(m, n)], i32_type),
-          ([lwork], a_type.element_type),
-      ]
-      workspace_layouts = [[0], [0]]
-    elif dtype == np.float64:
-      singular_vals_type = ir.F64Type.get()
-      lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
-      workspace = [
-          ([_lapack.gesdd_iwork_size(m, n)], i32_type),
-          ([lwork], a_type.element_type),
-      ]
-      workspace_layouts = [[0], [0]]
-    elif dtype == np.complex64:
-      singular_vals_type = ir.F32Type.get()
-      lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
-      workspace = [
-          ([_lapack.gesdd_iwork_size(m, n)], i32_type),
-          ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()),
-          ([lwork], a_type.element_type),
-      ]
-      workspace_layouts = [[0], [0], [0]]
-    elif dtype == np.complex128:
-      singular_vals_type = ir.F64Type.get()
-      lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
-      workspace = [
-          ([_lapack.gesdd_iwork_size(m, n)], i32_type),
-          ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()),
-          ([lwork], a_type.element_type),
-      ]
-      workspace_layouts = [[0], [0], [0]]
-    else:
-      raise NotImplementedError(f"Unsupported dtype {dtype}")
-
-    scalar_layout = []
-    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
-
-    shape_type_pairs: Sequence[ShapeTypePair] = [
-        (a_shape_vals, a_type.element_type),
-        (batch_dims_vals + (min(m, n),), singular_vals_type),
-        (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type),
-        (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type),
-        (batch_dims_vals, i32_type),
-    ] + workspace
-    result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
-    return custom_call(
-        fn,
-        result_types=result_types,
-        operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val,
-                  hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a],
-        operand_layouts=[scalar_layout] * 6 + [layout],
-        result_layouts=[
-            layout,
-            (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
-            layout,
-            layout,
-            tuple(range(num_bd - 1, -1, -1)),
-        ] + workspace_layouts,
-        operand_output_aliases={6: 0},
-        result_shapes=result_shapes
-    ).results[1:5]
-  fn = fn_base + "_ffi"
-  mode_attr = _svd_computation_attr(
-      compute_uv=compute_uv, full_matrices=full_matrices
-  )
-  if dtype == np.float32 or dtype == np.complex64:
-    singular_vals_type = ir.F32Type.get()
-  elif dtype == np.float64 or dtype == np.complex128:
-    singular_vals_type = ir.F64Type.get()
-  else:
-    raise NotImplementedError(f"Unsupported dtype {dtype}")
-
-  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
-  a_elem_type = a_type.element_type
-  shape_type_pairs: Sequence[ShapeTypePair] = [
-      (a_shape_vals, a_elem_type),
-      (batch_dims_vals + (min(m, n),), singular_vals_type),
-      (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_elem_type),
-      (batch_dims_vals + (n if full_matrices else min(m, n), n), a_elem_type),
-      (batch_dims_vals, i32_type),
-  ]
-  result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
-  return custom_call(
-      fn,
-      result_types=result_types,
-      operands=[a],
-      operand_layouts=[layout],
-      result_layouts=[
-          layout,
-          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
-          layout,
-          layout,
-          tuple(range(num_bd - 1, -1, -1)),
-      ],
-      operand_output_aliases={0: 0},
-      result_shapes=result_shapes,
-      backend_config={
-          "mode": mode_attr,
-      },
-      api_version=4,
-  ).results[1:]
-
-
 # # geev: Nonsymmetric eigendecomposition (eig)

 def geev_hlo(ctx, dtype, input, *,
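With gesdd_hlo deleted here, the CPU lowering in lax builds the FFI call directly and passes the "mode" attribute instead of precomputed workspaces. For context, that mode follows the LAPACK gesdd JOBZ convention; a hedged sketch of the mapping (the real helper is _svd_computation_attr in jaxlib/lapack.py; this standalone function is illustrative):

    def svd_computation_mode(compute_uv: bool, full_matrices: bool) -> str:
      # LAPACK gesdd JOBZ values: "A" = all columns of U/V^H, "S" = only the
      # leading min(m, n) columns, "N" = singular values only.
      if not compute_uv:
        return "N"
      return "A" if full_matrices else "S"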
@@ -47,6 +47,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenb
 from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf
+from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd
 from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
 from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
 from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
@@ -135,6 +136,7 @@ class CompatTest(bctu.CompatTestBase):
         cuda_lu_cusolver_getrf.data_2024_08_19,
         cuda_qr_cusolver_geqrf.data_2024_09_26,
         cuda_eigh_cusolver_syev.data_2024_09_30,
+        cuda_svd_cusolver_gesvd.data_2024_10_08,
         rocm_qr_hipsolver_geqrf.data_2024_08_05,
         rocm_eigh_hipsolver_syev.data_2024_08_05,
         cpu_schur_lapack_gees.data_2023_07_16,
@@ -166,6 +168,7 @@ class CompatTest(bctu.CompatTestBase):
         # The following require ROCm to test
         "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
         "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
+        "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
     })
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered,
@ -646,28 +649,48 @@ class CompatTest(bctu.CompatTestBase):
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
c64=np.complex64, c128=np.complex128)[dtype_name]
|
||||
shape = (2, 4, 4)
|
||||
input = jtu.rand_default(self.rng())(shape, dtype)
|
||||
# del input # Input is in the testdata, here for readability
|
||||
def func(input):
|
||||
return lax.linalg.svd(input, full_matrices=True, compute_uv=True)
|
||||
def func(operand):
|
||||
return lax.linalg.svd(operand, full_matrices=True, compute_uv=True)
|
||||
|
||||
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
|
||||
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
|
||||
|
||||
info = cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
|
||||
data = self.load_testdata(info)
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol,
|
||||
check_results=partial(self.check_svd_results,
|
||||
*data.inputs))
|
||||
|
||||
data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name])
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol,
|
||||
check_results=partial(self.check_svd_results,
|
||||
input))
|
||||
with config.export_ignore_forward_compatibility(True):
|
||||
# FFI Kernel test
|
||||
data = self.load_testdata(
|
||||
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
|
||||
)
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol,
|
||||
check_results=partial(self.check_svd_results, input))
|
||||
*data.inputs),
|
||||
expect_current_custom_calls=info["custom_call_targets"])
|
||||
|
||||
  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}",
           dtype_name=dtype_name, algorithm_name=algorithm_name)
      for dtype_name in ("f32", "f64", "c64", "c128")
      for algorithm_name in ("qr", "jacobi"))
  @jax.default_matmul_precision("float32")
  def test_gpu_svd_solver_gesvd(self, dtype_name, algorithm_name):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    def func(operand):
      return lax.linalg.svd(operand, full_matrices=True, compute_uv=True,
                            algorithm=algorithm)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    algorithm = dict(qr=lax.linalg.SvdAlgorithm.QR,
                     jacobi=lax.linalg.SvdAlgorithm.JACOBI)[algorithm_name]

    info = cuda_svd_cusolver_gesvd.data_2024_10_08[algorithm_name][dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_svd_results,
                                            *data.inputs))

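For reference, the user-facing surface this test exercises is the `algorithm` argument to `lax.linalg.svd`, as used in the hunk above. A minimal sketch of how a caller might select a kernel explicitly (the input matrix here is a placeholder; on CUDA the two choices map to the gesvd and gesvdj custom-call targets registered earlier in this commit):

import jax.numpy as jnp
from jax import lax

a = jnp.ones((32, 8))  # any (m, n) matrix

# Default: the lowering heuristics pick the algorithm from the matrix size.
u, s, vh = lax.linalg.svd(a, full_matrices=True, compute_uv=True)

# Bypass the heuristics and request a specific kernel.
u, s, vh = lax.linalg.svd(a, full_matrices=True, compute_uv=True,
                          algorithm=lax.linalg.SvdAlgorithm.QR)
u, s, vh = lax.linalg.svd(a, full_matrices=True, compute_uv=True,
                          algorithm=lax.linalg.SvdAlgorithm.JACOBI)
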
  @jtu.parameterized_filterable(
      kwargs=[

@ -48,7 +48,6 @@ from jax._src.export import shape_poly
from jax._src.export import shape_poly_decision
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import version as jaxlib_version
import numpy as np

config.parse_flags_with_absl()

@ -2880,6 +2879,28 @@ _POLY_SHAPE_TEST_HARNESSES = [
          ((2, 3, 4, 5), "b1, b2, m, n"),
        ]
    ],
    [
        PolyHarness(  # pylint: disable=g-complex-comprehension
            "svd", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}_{compute_uv=}",
            lambda x, full_matrices, compute_uv: lax.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv),
            arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices), StaticArg(compute_uv)],
            polymorphic_shapes=[poly],
            symbolic_constraints=constraints)
        for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
        for compute_uv in [True, False]
        for full_matrices in ([True, False] if compute_uv else [True])
        for shape, poly, constraints in [
            ((2, 0, 4), "b, ...", ()),
            ((2, 4, 0), "b, ...", ()),
            ((2, 3, 4, 4), "b1, b2, ...", ()),
            ((2, 3, 4, 5), "b1, b2, ...", ()),
            ((2, 3, 8, 4), "b1, b2, ...", ()),
            # The constraints listed here are only for the GPU implementation
            # which selects an algorithm based on the size of the matrix.
            ((5, 4), "m, n", ["n <= m", "m <= 32", "n <= 32"]),
            ((2, 3, 4, 5), "b1, b2, m, n", ["m <= n", "m <= 32", "n <= 32"]),
        ]
    ],
    [
        # The random primitive tests, with threefry (both partitionable and
        # non-partitionable), and unsafe_rbg.

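The last two harness entries above pin down `m` and `n` with symbolic constraints so the GPU lowering can commit to one algorithm at trace time even though the dimensions stay symbolic. A rough sketch of exercising the same path through the public `jax.export` API (the shape spec and constraint strings mirror the `((5, 4), "m, n", ...)` entry; the function and dtype are illustrative):

import jax
import numpy as np
from jax import lax

def f(x):
  return lax.linalg.svd(x, full_matrices=True, compute_uv=True)

# Bound the symbolic dimensions tightly enough for the lowering heuristics
# to resolve to a concrete kernel choice.
m, n = jax.export.symbolic_shape(
    "m, n", constraints=["n <= m", "m <= 32", "n <= 32"])
exported = jax.export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((m, n), np.float32))
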
@ -3498,23 +3519,6 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
    if harness.expect_error == expect_error_associative_scan and jtu.test_device_matches(["tpu"]):
      harness.expect_error = None

    # Exclude some harnesses that are known to fail for native serialization
    # Set of harness.group_name:platform that are implemented with custom call
    custom_call_harnesses = {
        "vmap_svd:gpu",
    }
    name_device_key = f"{harness.group_name}:{jtu.device_under_test()}"
    if name_device_key in custom_call_harnesses:
      raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")

    # This list keeps track of the minimum jaxlib version that supports shape
    # polymorphism for some new primitives as we add them. This check is
    # required so that we can still run the test suite with older versions of
    # jaxlib.
    version_gated = {}
    if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version:
      raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}")

    if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("schur decomposition is only implemented on CPU.")

@ -3553,11 +3557,11 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
    if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("JAX implements eig only on CPU.")

    if (harness.group_name == "eigh" and
    if (harness.group_name in ("eigh", "svd") and
        not harness.polymorphic_shapes[0].endswith("...") and
        jtu.test_device_matches(["tpu"])):
      raise unittest.SkipTest(
          "Shape polymorphism for Eigh is only supported for batch dimensions on TPU.")
          "Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.")

    config_flags = harness.override_jax_config_flags
    # Update this here rather than in harness object because vmap_random_gamma is derived