mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.qr: fix return type for mode='r'
This commit is contained in:
parent
04b6f15cdb
commit
822b6aad3b
@ -17,6 +17,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* `jax.experimental.maps.mesh` has been deleted.
|
* `jax.experimental.maps.mesh` has been deleted.
|
||||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||||
for more information.
|
for more information.
|
||||||
|
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
|
||||||
|
`mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`)
|
||||||
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
|
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
|
||||||
that specifies the behavior of out-of-bounds indexing. By default,
|
that specifies the behavior of out-of-bounds indexing. By default,
|
||||||
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
||||||
|
@ -176,7 +176,7 @@ def _qr(a, mode, pivoting):
|
|||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
q, r = lax_linalg.qr(a, full_matrices)
|
q, r = lax_linalg.qr(a, full_matrices)
|
||||||
if mode == "r":
|
if mode == "r":
|
||||||
return r
|
return (r,)
|
||||||
return q, r
|
return q, r
|
||||||
|
|
||||||
@_wraps(scipy.linalg.qr)
|
@_wraps(scipy.linalg.qr)
|
||||||
|
@ -620,6 +620,38 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
|||||||
# do not check it functionality here
|
# do not check it functionality here
|
||||||
jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))
|
jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{"testcase_name": "_shape={}_mode={}".format(
|
||||||
|
jtu.format_shape_dtype_string(shape, dtype), mode),
|
||||||
|
"shape": shape, "dtype": dtype, "mode": mode}
|
||||||
|
for shape in [(3, 4), (3, 3), (4, 3)]
|
||||||
|
for dtype in [np.float32]
|
||||||
|
for mode in ["full", "r", "economic"]))
|
||||||
|
def testScipyQrModes(self, shape, dtype, mode):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
|
||||||
|
sp_func = partial(scipy.linalg.qr, mode=mode)
|
||||||
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
|
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||||
|
self._CompileAndCheck(jsp_func, args_maker)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{"testcase_name": "_shape={}_mode={}".format(
|
||||||
|
jtu.format_shape_dtype_string(shape, dtype), mode),
|
||||||
|
"shape": shape, "dtype": dtype, "mode": mode}
|
||||||
|
for shape in [(3, 4), (3, 3), (4, 3)]
|
||||||
|
for dtype in [np.float32]
|
||||||
|
for mode in ["reduced", "r", "full", "complete"]))
|
||||||
|
def testNumpyQrModes(self, shape, dtype, mode):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
|
||||||
|
np_func = partial(np.linalg.qr, mode=mode)
|
||||||
|
if mode == "full":
|
||||||
|
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
|
||||||
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
|
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||||
|
self._CompileAndCheck(jnp_func, args_maker)
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_shape={}_fullmatrices={}".format(
|
{"testcase_name": "_shape={}_fullmatrices={}".format(
|
||||||
jtu.format_shape_dtype_string(shape, dtype), full_matrices),
|
jtu.format_shape_dtype_string(shape, dtype), full_matrices),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user