mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used. PiperOrigin-RevId: 662024940
This commit is contained in:
parent
ae5b4284d5
commit
3c014a4c27
@ -960,7 +960,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
|
||||
# schur on CPU
|
||||
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
||||
# # lu on GPU
|
||||
# lu on GPU
|
||||
"cu_lu_pivots_to_permutation",
|
||||
# "cublas_getrf_batched", "cusolver_getrf",
|
||||
# "hipblas_getrf_batched", "hipsolver_getrf",
|
||||
# TODO(b/357034884): This can be added once the mimimum version of jaxlib
|
||||
|
@ -0,0 +1,55 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
from numpy import array, int32
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_08_08 = dict(
|
||||
testdata_version=1,
|
||||
platform='cuda',
|
||||
custom_call_targets=['cu_lu_pivots_to_permutation'],
|
||||
serialized_date=datetime.date(2024, 8, 8),
|
||||
inputs=(),
|
||||
expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[4, 5, 6, 7, 0, 1, 2, 3],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]],
|
||||
|
||||
[[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),),
|
||||
mlir_module_text=r"""
|
||||
module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
|
||||
%0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4)
|
||||
%1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5)
|
||||
%c = stablehlo.constant dense<2> : tensor<i64> loc(#loc6)
|
||||
%c_0 = stablehlo.constant dense<3> : tensor<i64> loc(#loc6)
|
||||
%c_1 = stablehlo.constant dense<4> : tensor<i64> loc(#loc6)
|
||||
%2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6)
|
||||
return %2 : tensor<2x3x8xi32> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14)
|
||||
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11)
|
||||
#loc4 = loc("jit(<lambda>)/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1))
|
||||
#loc5 = loc("jit(<lambda>)/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2))
|
||||
#loc6 = loc("jit(<lambda>)/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(<lambda>)/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit(<lambda>)/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit(<lambda>)/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
@ -1159,8 +1159,9 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
|
||||
len(batch_dims))
|
||||
if m == 0:
|
||||
return permutation
|
||||
result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
|
||||
_lu_pivots_body_fn, (permutation, swaps))
|
||||
upper = np.array(k, np.int32) if is_constant_dim(k) else k
|
||||
result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn,
|
||||
(permutation, swaps))
|
||||
return result
|
||||
|
||||
|
||||
@ -1171,19 +1172,14 @@ def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
|
||||
raise ValueError(
|
||||
'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
|
||||
'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))
|
||||
|
||||
pivots_size = pivots.shape[-1]
|
||||
if permutation_size < pivots_size:
|
||||
if not permutation_size >= pivots_size:
|
||||
raise ValueError(
|
||||
'Output permutation size {} has to exceed the trailing dimension of '
|
||||
'the pivots. Got pivots size {}'.format(permutation_size, pivots_size))
|
||||
|
||||
batch_dims = pivots.shape[:-1]
|
||||
permutations = pivots.update(shape=batch_dims + (permutation_size,))
|
||||
return pivots.update(shape=(*pivots.shape[:-1], permutation_size))
|
||||
else:
|
||||
permutations = pivots
|
||||
|
||||
return permutations
|
||||
return pivots
|
||||
|
||||
|
||||
def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
|
||||
@ -1196,7 +1192,14 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
|
||||
|
||||
def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
|
||||
permutation_size):
|
||||
return lowering(pivots, permutation_size=permutation_size)
|
||||
# TODO(danfm): Remove once jaxlib 0.4.32 is the minimum version.
|
||||
if jaxlib_version >= (0, 4, 32):
|
||||
pivots_aval, = ctx.avals_in
|
||||
pivots_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, pivots_aval.shape)
|
||||
kwargs = dict(pivots_shape_vals=pivots_shape_vals)
|
||||
else:
|
||||
kwargs = {}
|
||||
return lowering(pivots, permutation_size=permutation_size, **kwargs)
|
||||
|
||||
|
||||
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
|
||||
|
@ -61,16 +61,17 @@ if _hip_linalg:
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
||||
|
||||
def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size):
|
||||
def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size,
|
||||
pivots_shape_vals):
|
||||
"""Kernel for the transformation of pivots to permutations on GPU."""
|
||||
typ = ir.RankedTensorType(pivots.type)
|
||||
dims = typ.shape
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
assert typ.element_type == i32_type, typ
|
||||
assert len(pivots_shape_vals) >= 1
|
||||
|
||||
pivots_layout = tuple(range(len(dims) - 1, -1, -1))
|
||||
pivots_layout = tuple(range(len(pivots_shape_vals) - 1, -1, -1))
|
||||
permutations_layout = pivots_layout
|
||||
permutations_dims = (*dims[:-1], permutation_size)
|
||||
permutations_dims = (*pivots_shape_vals[:-1], permutation_size)
|
||||
result_types, result_shapes = mk_result_types_and_shapes(
|
||||
[(permutations_dims, i32_type)])
|
||||
return custom_call(
|
||||
|
@ -44,6 +44,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_l
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
|
||||
@ -124,6 +125,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cpu_qr_lapack_geqrf.data_2023_03_17,
|
||||
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
|
||||
cpu_lu_lapack_getrf.data_2023_06_14,
|
||||
cuda_lu_pivots_to_permutation.data_2024_08_08,
|
||||
cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
|
||||
rocm_qr_hipsolver_geqrf.data_2024_08_05,
|
||||
rocm_eigh_hipsolver_syev.data_2024_08_05,
|
||||
@ -342,6 +344,17 @@ class CompatTest(bctu.CompatTestBase):
|
||||
self.run_one_test(func, data, rtol=1e-3,
|
||||
check_results=partial(self.check_eigh_results, operand))
|
||||
|
||||
@staticmethod
|
||||
def lu_pivots_to_permutation_harness(shape):
|
||||
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape)
|
||||
return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8)
|
||||
|
||||
def test_cuda_lu_pivots_to_permutation(self):
|
||||
shape = (2, 3, 4)
|
||||
func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape)
|
||||
data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08)
|
||||
self.run_one_test(func, data)
|
||||
|
||||
@staticmethod
|
||||
def qr_harness(shape, dtype):
|
||||
# In order to keep inputs small, we construct the input programmatically
|
||||
|
@ -2731,6 +2731,41 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
((2, 3, 8, 4), "b1, b2, ...", True),
|
||||
]
|
||||
],
|
||||
[
|
||||
PolyHarness(
|
||||
"lu_pivots_to_permutation",
|
||||
f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}",
|
||||
lax.linalg.lu_pivots_to_permutation,
|
||||
arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)],
|
||||
polymorphic_shapes=[poly],
|
||||
symbolic_constraints=constraints,
|
||||
)
|
||||
for shape, poly, permutation_size, constraints in [
|
||||
((4,), None, 8, ()),
|
||||
((2, 3, 4), "b1, b2, ...", 8, ()),
|
||||
((4,), "b", 8, ["b <= 8"]),
|
||||
((2, 3, 4), "b1, b2, b3", 8, ["b3 <= 8"]),
|
||||
]
|
||||
],
|
||||
[
|
||||
# Tracing errors are only thrown when the trailing dimension of pivots
|
||||
# is static. Otherwise, the error is thrown at runtime.
|
||||
PolyHarness(
|
||||
"lu_pivots_to_permutation_error",
|
||||
f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}",
|
||||
lax.linalg.lu_pivots_to_permutation,
|
||||
arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)],
|
||||
polymorphic_shapes=[poly],
|
||||
symbolic_constraints=constraints,
|
||||
expect_error=(ValueError, "Output permutation size"),
|
||||
)
|
||||
for shape, poly, permutation_size, constraints in [
|
||||
((4,), None, 3, ()),
|
||||
((2, 3, 4), "b1, b2, ...", 3, ()),
|
||||
((4,), "b", 8, ["b >= 9"]),
|
||||
((2, 3, 4), "b1, b2, b3", 8, ["b3 >= 9"]),
|
||||
]
|
||||
],
|
||||
[
|
||||
# The random primitive tests, with threefry (both partitionable and
|
||||
# non-partitionable), and unsafe_rbg.
|
||||
|
Loading…
x
Reference in New Issue
Block a user