mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.qr: fix return type for mode='r'
This commit is contained in:
parent
04b6f15cdb
commit
822b6aad3b
@ -17,6 +17,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* `jax.experimental.maps.mesh` has been deleted.
|
||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||
for more information.
|
||||
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
|
||||
`mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`)
|
||||
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
|
||||
that specifies the behavior of out-of-bounds indexing. By default,
|
||||
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
||||
|
@ -176,7 +176,7 @@ def _qr(a, mode, pivoting):
|
||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||
q, r = lax_linalg.qr(a, full_matrices)
|
||||
if mode == "r":
|
||||
return r
|
||||
return (r,)
|
||||
return q, r
|
||||
|
||||
@_wraps(scipy.linalg.qr)
|
||||
|
@ -620,6 +620,38 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
# do not check it functionality here
|
||||
jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_mode={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), mode),
|
||||
"shape": shape, "dtype": dtype, "mode": mode}
|
||||
for shape in [(3, 4), (3, 3), (4, 3)]
|
||||
for dtype in [np.float32]
|
||||
for mode in ["full", "r", "economic"]))
|
||||
def testScipyQrModes(self, shape, dtype, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
|
||||
sp_func = partial(scipy.linalg.qr, mode=mode)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||
self._CompileAndCheck(jsp_func, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_mode={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), mode),
|
||||
"shape": shape, "dtype": dtype, "mode": mode}
|
||||
for shape in [(3, 4), (3, 3), (4, 3)]
|
||||
for dtype in [np.float32]
|
||||
for mode in ["reduced", "r", "full", "complete"]))
|
||||
def testNumpyQrModes(self, shape, dtype, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
|
||||
np_func = partial(np.linalg.qr, mode=mode)
|
||||
if mode == "full":
|
||||
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||
self._CompileAndCheck(jnp_func, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_fullmatrices={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), full_matrices),
|
||||
|
Loading…
x
Reference in New Issue
Block a user