Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix

PiperOrigin-RevId: 733077011
This commit is contained in:
jax authors 2025-03-03 15:11:31 -08:00
commit 07d1cd0290
2 changed files with 26 additions and 8 deletions

View File

@ -1175,7 +1175,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
if not keepdims and col_axis > row_axis:
col_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
axis=col_axis, keepdims=keepdims, initial=0)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
@ -1185,7 +1185,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
axis=row_axis, keepdims=keepdims, initial=0)
elif ord == -np.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
@ -1193,14 +1193,13 @@ def norm(x: ArrayLike, ord: int | str | None = None,
axis=row_axis, keepdims=keepdims)
elif ord in ('nuc', 2, -2):
x = jnp.moveaxis(x, axis, (-2, -1))
s = svd(x, compute_uv=False)
if ord == 2:
reducer = reductions.amax
y = reductions.amax(s, axis=-1, initial=0)
elif ord == -2:
reducer = reductions.amin
y = reductions.amin(s, axis=-1)
else:
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
reducer = reductions.sum # type: ignore[assignment]
y = reducer(svd(x, compute_uv=False), axis=-1)
y = reductions.sum(s, axis=-1)
if keepdims:
y = jnp.expand_dims(y, axis)
return y
@ -1652,7 +1651,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == np.inf:
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims, initial=0)
elif ord == -np.inf:
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:

View File

@ -715,6 +715,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)
@jtu.sample_product(
shape=[(0, 2), (2, 0), (0, 0)],
dtype=float_types + complex_types,
ord=[1, 2, np.inf, 'fro', 'nuc'],
)
def testEmptyMatrixNorm(self, shape, dtype, ord):
x = jnp.zeros(shape, dtype)
norm = jnp.linalg.matrix_norm(x, ord=ord)
self.assertEqual(norm, 0)
@skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
@jtu.sample_product(
[
@ -735,6 +745,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
dtype=float_types + complex_types,
ord=[1, 2, np.inf],
)
def testEmptyVectorNorm(self, dtype, ord):
x = jnp.zeros(0, dtype)
norm = jnp.linalg.vector_norm(x, ord=ord)
self.assertEqual(norm, 0)
# jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
@jtu.sample_product(
[