Add support for shape polymorphism with lu_pivots_to_permutation.

This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
This commit is contained in:
Dan Foreman-Mackey 2024-08-12 03:38:55 -07:00 committed by jax authors
parent ae5b4284d5
commit 3c014a4c27
6 changed files with 124 additions and 16 deletions

View File

@ -960,7 +960,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
# schur on CPU
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
# # lu on GPU
# lu on GPU
"cu_lu_pivots_to_permutation",
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# TODO(b/357034884): This can be added once the mimimum version of jaxlib

View File

@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from numpy import array, int32
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_08_08 = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['cu_lu_pivots_to_permutation'],
serialized_date=datetime.date(2024, 8, 8),
inputs=(),
expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7],
[4, 5, 6, 7, 0, 1, 2, 3],
[0, 1, 2, 3, 4, 5, 6, 7]],
[[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),),
mlir_module_text=r"""
module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4)
%1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5)
%c = stablehlo.constant dense<2> : tensor<i64> loc(#loc6)
%c_0 = stablehlo.constant dense<3> : tensor<i64> loc(#loc6)
%c_1 = stablehlo.constant dense<4> : tensor<i64> loc(#loc6)
%2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6)
return %2 : tensor<2x3x8xi32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11)
#loc4 = loc("jit(<lambda>)/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1))
#loc5 = loc("jit(<lambda>)/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2))
#loc6 = loc("jit(<lambda>)/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(<lambda>)/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit(<lambda>)/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit(<lambda>)/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00",
xla_call_module_version=9,
nr_devices=1,
) # End paste

View File

@ -1159,8 +1159,9 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
len(batch_dims))
if m == 0:
return permutation
result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
_lu_pivots_body_fn, (permutation, swaps))
upper = np.array(k, np.int32) if is_constant_dim(k) else k
result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn,
(permutation, swaps))
return result
@ -1171,19 +1172,14 @@ def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
raise ValueError(
'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))
pivots_size = pivots.shape[-1]
if permutation_size < pivots_size:
if not permutation_size >= pivots_size:
raise ValueError(
'Output permutation size {} has to exceed the trailing dimension of '
'the pivots. Got pivots size {}'.format(permutation_size, pivots_size))
batch_dims = pivots.shape[:-1]
permutations = pivots.update(shape=batch_dims + (permutation_size,))
return pivots.update(shape=(*pivots.shape[:-1], permutation_size))
else:
permutations = pivots
return permutations
return pivots
def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
@ -1196,7 +1192,14 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
permutation_size):
return lowering(pivots, permutation_size=permutation_size)
# TODO(danfm): Remove once jaxlib 0.4.32 is the minimum version.
if jaxlib_version >= (0, 4, 32):
pivots_aval, = ctx.avals_in
pivots_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, pivots_aval.shape)
kwargs = dict(pivots_shape_vals=pivots_shape_vals)
else:
kwargs = {}
return lowering(pivots, permutation_size=permutation_size, **kwargs)
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')

View File

@ -61,16 +61,17 @@ if _hip_linalg:
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size):
def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size,
pivots_shape_vals):
"""Kernel for the transformation of pivots to permutations on GPU."""
typ = ir.RankedTensorType(pivots.type)
dims = typ.shape
i32_type = ir.IntegerType.get_signless(32)
assert typ.element_type == i32_type, typ
assert len(pivots_shape_vals) >= 1
pivots_layout = tuple(range(len(dims) - 1, -1, -1))
pivots_layout = tuple(range(len(pivots_shape_vals) - 1, -1, -1))
permutations_layout = pivots_layout
permutations_dims = (*dims[:-1], permutation_size)
permutations_dims = (*pivots_shape_vals[:-1], permutation_size)
result_types, result_shapes = mk_result_types_and_shapes(
[(permutations_dims, i32_type)])
return custom_call(

View File

@ -44,6 +44,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_l
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
@ -124,6 +125,7 @@ class CompatTest(bctu.CompatTestBase):
cpu_qr_lapack_geqrf.data_2023_03_17,
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
cpu_lu_lapack_getrf.data_2023_06_14,
cuda_lu_pivots_to_permutation.data_2024_08_08,
cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
rocm_qr_hipsolver_geqrf.data_2024_08_05,
rocm_eigh_hipsolver_syev.data_2024_08_05,
@ -342,6 +344,17 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, rtol=1e-3,
check_results=partial(self.check_eigh_results, operand))
@staticmethod
def lu_pivots_to_permutation_harness(shape):
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape)
return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8)
def test_cuda_lu_pivots_to_permutation(self):
shape = (2, 3, 4)
func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape)
data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08)
self.run_one_test(func, data)
@staticmethod
def qr_harness(shape, dtype):
# In order to keep inputs small, we construct the input programmatically

View File

@ -2731,6 +2731,41 @@ _POLY_SHAPE_TEST_HARNESSES = [
((2, 3, 8, 4), "b1, b2, ...", True),
]
],
[
PolyHarness(
"lu_pivots_to_permutation",
f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}",
lax.linalg.lu_pivots_to_permutation,
arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)],
polymorphic_shapes=[poly],
symbolic_constraints=constraints,
)
for shape, poly, permutation_size, constraints in [
((4,), None, 8, ()),
((2, 3, 4), "b1, b2, ...", 8, ()),
((4,), "b", 8, ["b <= 8"]),
((2, 3, 4), "b1, b2, b3", 8, ["b3 <= 8"]),
]
],
[
# Tracing errors are only thrown when the trailing dimension of pivots
# is static. Otherwise, the error is thrown at runtime.
PolyHarness(
"lu_pivots_to_permutation_error",
f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}",
lax.linalg.lu_pivots_to_permutation,
arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)],
polymorphic_shapes=[poly],
symbolic_constraints=constraints,
expect_error=(ValueError, "Output permutation size"),
)
for shape, poly, permutation_size, constraints in [
((4,), None, 3, ()),
((2, 3, 4), "b1, b2, ...", 3, ()),
((4,), "b", 8, ["b >= 9"]),
((2, 3, 4), "b1, b2, b3", 8, ["b3 >= 9"]),
]
],
[
# The random primitive tests, with threefry (both partitionable and
# non-partitionable), and unsafe_rbg.