Fix jax.numpy.linalg.inv with shape polymorphism

This commit is contained in:
tchatow 2024-09-16 12:05:43 -04:00 committed by Tyler Chatow
parent 8c39d0373a
commit 520980171f
2 changed files with 6 additions and 1 deletions

View File

@ -1620,7 +1620,7 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array
conjugate_a=conj)
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True,
transpose_a=True, conjugate_a=conj)
_, ind = lax.sort_key_val(permutation, lax.iota('int32', len(permutation)))
_, ind = lax.sort_key_val(permutation, lax.iota('int32', permutation.shape[0]))
x = x[ind, :]
else:
raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}")

View File

@ -2769,6 +2769,11 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: jnp.nanquantile(x, .5, axis=0),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("inv", "",
lambda x: jnp.linalg.inv(jnp.eye(x.shape[0])),
arg_descriptors=[RandArg((3, 3), _f32)],
polymorphic_shapes=["b, b, ..."],
override_jax_config_flags={"jax_export_ignore_forward_compatibility": True}),
[
PolyHarness(
"qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}",