jax.scipy.qr: fix return type for mode='r'

This commit is contained in:
Jake VanderPlas 2022-04-26 11:26:56 -07:00
parent 04b6f15cdb
commit 822b6aad3b
3 changed files with 35 additions and 1 deletions

View File

@ -17,6 +17,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `jax.experimental.maps.mesh` has been deleted.
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
`mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`)
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
that specifies the behavior of out-of-bounds indexing. By default,
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In

View File

@ -176,7 +176,7 @@ def _qr(a, mode, pivoting):
a, = _promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices)
if mode == "r":
return r
return (r,)
return q, r
@_wraps(scipy.linalg.qr)

View File

@ -620,6 +620,38 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# do not check it functionality here
jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode),
"shape": shape, "dtype": dtype, "mode": mode}
for shape in [(3, 4), (3, 3), (4, 3)]
for dtype in [np.float32]
for mode in ["full", "r", "economic"]))
def testScipyQrModes(self, shape, dtype, mode):
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
sp_func = partial(scipy.linalg.qr, mode=mode)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(jsp_func, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode),
"shape": shape, "dtype": dtype, "mode": mode}
for shape in [(3, 4), (3, 3), (4, 3)]
for dtype in [np.float32]
for mode in ["reduced", "r", "full", "complete"]))
def testNumpyQrModes(self, shape, dtype, mode):
rng = jtu.rand_default(self.rng())
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
np_func = partial(np.linalg.qr, mode=mode)
if mode == "full":
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(jnp_func, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_fullmatrices={}".format(
jtu.format_shape_dtype_string(shape, dtype), full_matrices),