mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix jax.numpy.linalg.inv with shape polymorphism
This commit is contained in:
parent
8c39d0373a
commit
520980171f
@ -1620,7 +1620,7 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array
|
||||
conjugate_a=conj)
|
||||
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True,
|
||||
transpose_a=True, conjugate_a=conj)
|
||||
_, ind = lax.sort_key_val(permutation, lax.iota('int32', len(permutation)))
|
||||
_, ind = lax.sort_key_val(permutation, lax.iota('int32', permutation.shape[0]))
|
||||
x = x[ind, :]
|
||||
else:
|
||||
raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}")
|
||||
|
@ -2769,6 +2769,11 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: jnp.nanquantile(x, .5, axis=0),
|
||||
arg_descriptors=[RandArg((3, 5), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness("inv", "",
|
||||
lambda x: jnp.linalg.inv(jnp.eye(x.shape[0])),
|
||||
arg_descriptors=[RandArg((3, 3), _f32)],
|
||||
polymorphic_shapes=["b, b, ..."],
|
||||
override_jax_config_flags={"jax_export_ignore_forward_compatibility": True}),
|
||||
[
|
||||
PolyHarness(
|
||||
"qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}",
|
||||
|
Loading…
x
Reference in New Issue
Block a user