mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
This commit is contained in:
commit
07d1cd0290
@ -1175,7 +1175,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
|
||||
if not keepdims and col_axis > row_axis:
|
||||
col_axis -= 1
|
||||
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
|
||||
axis=col_axis, keepdims=keepdims)
|
||||
axis=col_axis, keepdims=keepdims, initial=0)
|
||||
elif ord == -1:
|
||||
if not keepdims and col_axis > row_axis:
|
||||
col_axis -= 1
|
||||
@ -1185,7 +1185,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
|
||||
if not keepdims and row_axis > col_axis:
|
||||
row_axis -= 1
|
||||
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
|
||||
axis=row_axis, keepdims=keepdims)
|
||||
axis=row_axis, keepdims=keepdims, initial=0)
|
||||
elif ord == -np.inf:
|
||||
if not keepdims and row_axis > col_axis:
|
||||
row_axis -= 1
|
||||
@ -1193,14 +1193,13 @@ def norm(x: ArrayLike, ord: int | str | None = None,
|
||||
axis=row_axis, keepdims=keepdims)
|
||||
elif ord in ('nuc', 2, -2):
|
||||
x = jnp.moveaxis(x, axis, (-2, -1))
|
||||
s = svd(x, compute_uv=False)
|
||||
if ord == 2:
|
||||
reducer = reductions.amax
|
||||
y = reductions.amax(s, axis=-1, initial=0)
|
||||
elif ord == -2:
|
||||
reducer = reductions.amin
|
||||
y = reductions.amin(s, axis=-1)
|
||||
else:
|
||||
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
|
||||
reducer = reductions.sum # type: ignore[assignment]
|
||||
y = reducer(svd(x, compute_uv=False), axis=-1)
|
||||
y = reductions.sum(s, axis=-1)
|
||||
if keepdims:
|
||||
y = jnp.expand_dims(y, axis)
|
||||
return y
|
||||
@ -1652,7 +1651,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k
|
||||
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
|
||||
keepdims=keepdims))
|
||||
elif ord == np.inf:
|
||||
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
|
||||
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims, initial=0)
|
||||
elif ord == -np.inf:
|
||||
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
|
||||
elif ord == 0:
|
||||
|
@ -715,6 +715,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(0, 2), (2, 0), (0, 0)],
|
||||
dtype=float_types + complex_types,
|
||||
ord=[1, 2, np.inf, 'fro', 'nuc'],
|
||||
)
|
||||
def testEmptyMatrixNorm(self, shape, dtype, ord):
|
||||
x = jnp.zeros(shape, dtype)
|
||||
norm = jnp.linalg.matrix_norm(x, ord=ord)
|
||||
self.assertEqual(norm, 0)
|
||||
|
||||
@skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
|
||||
@jtu.sample_product(
|
||||
[
|
||||
@ -735,6 +745,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_types + complex_types,
|
||||
ord=[1, 2, np.inf],
|
||||
)
|
||||
def testEmptyVectorNorm(self, dtype, ord):
|
||||
x = jnp.zeros(0, dtype)
|
||||
norm = jnp.linalg.vector_norm(x, ord=ord)
|
||||
self.assertEqual(norm, 0)
|
||||
|
||||
# jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
|
||||
@jtu.sample_product(
|
||||
[
|
||||
|
Loading…
x
Reference in New Issue
Block a user