Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #25271 from jakevdp:fix-vector-norm
PiperOrigin-RevId: 703178721
commit a71f9a62e6
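
This change moves the single-axis vector-norm computation out of `norm` and into `vector_norm`: `norm` now delegates to `vector_norm` for the one-axis case, and `vector_norm`'s `axis` annotation becomes `int | tuple[int, ...] | None` so it can reduce over several axes at once. A minimal usage sketch of the behavior the new signature allows (illustrative only, not part of the commit; the array `x` is a made-up example):

    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    # Reduce over both axes at once; the previous annotation allowed only a single int axis (or None).
    jnp.linalg.vector_norm(x, axis=(0, 1), ord=2)          # same value as vector_norm(x.ravel())
    jnp.linalg.vector_norm(x, axis=(0, 1), keepdims=True)  # result keeps shape (1, 1)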
jax/_src/numpy/linalg.py

@@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
 
   num_axes = len(axis)
   if num_axes == 1:
-    if ord is None or ord == 2:
-      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
-                                        keepdims=keepdims))
-    elif ord == jnp.inf:
-      return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == -jnp.inf:
-      return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == 0:
-      return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
-                            axis=axis, keepdims=keepdims)
-    elif ord == 1:
-      # Numpy has a special case for ord == 1 as an optimization. We don't
-      # really need the optimization (XLA could do it for us), but the Numpy
-      # code has slightly different type promotion semantics, so we need a
-      # special case too.
-      return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif isinstance(ord, str):
-      msg = f"Invalid order '{ord}' for vector norm."
-      if ord == "inf":
-        msg += "Use 'jax.numpy.inf' instead."
-      if ord == "-inf":
-        msg += "Use '-jax.numpy.inf' instead."
-      raise ValueError(msg)
-    else:
-      abs_x = ufuncs.abs(x)
-      ord_arr = lax_internal._const(abs_x, ord)
-      ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
-      out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
-      return ufuncs.power(out, ord_inv)
+    return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims)
 
   elif num_axes == 2:
     row_axis, col_axis = axis  # pytype: disable=bad-unpacking
@@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
 
 
 @export
-def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
+def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False,
                 ord: int | str = 2) -> Array:
   """Compute the vector norm of a vector or batch of vectors.
 
@@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
     Array([3.7416575, 9.486833 ], dtype=float32)
   """
   check_arraylike('jnp.linalg.vector_norm', x)
-  if axis is None:
-    result = norm(jnp.ravel(x), ord=ord)
-    if keepdims:
-      result = lax.expand_dims(result, range(jnp.ndim(x)))
-    return result
-  return norm(x, axis=axis, keepdims=keepdims, ord=ord)
-
+  if ord is None or ord == 2:
+    return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
+                                      keepdims=keepdims))
+  elif ord == jnp.inf:
+    return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == -jnp.inf:
+    return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == 0:
+    return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
+                          axis=axis, keepdims=keepdims)
+  elif ord == 1:
+    # Numpy has a special case for ord == 1 as an optimization. We don't
+    # really need the optimization (XLA could do it for us), but the Numpy
+    # code has slightly different type promotion semantics, so we need a
+    # special case too.
+    return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif isinstance(ord, str):
+    msg = f"Invalid order '{ord}' for vector norm."
+    if ord == "inf":
+      msg += "Use 'jax.numpy.inf' instead."
+    if ord == "-inf":
+      msg += "Use '-jax.numpy.inf' instead."
+    raise ValueError(msg)
+  else:
+    abs_x = ufuncs.abs(x)
+    ord_arr = lax_internal._const(abs_x, ord)
+    ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
+    out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
+    return ufuncs.power(out, ord_inv)
 
 @export
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
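
For orders other than the special-cased values, the relocated `vector_norm` body computes the p-norm as (sum_i |x_i|^p)^(1/p), materializing `ord` as a constant of the input dtype. A rough equivalence sketch in plain `jax.numpy` (illustrative only; `x`, `p`, `manual`, and `via_api` are made-up names, not part of the diff):

    import jax.numpy as jnp

    x = jnp.array([[1.0, -2.0], [3.0, -4.0]])
    p = 3
    # Hand-rolled p-norm over both axes, for comparison with the API call.
    manual = jnp.sum(jnp.abs(x) ** p, axis=(0, 1)) ** (1.0 / p)
    # Expected to agree with `manual` up to floating-point error.
    via_api = jnp.linalg.vector_norm(x, ord=p, axis=(0, 1))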
tests/linalg_test.py

@@ -16,6 +16,8 @@
 
 from functools import partial
 import itertools
+from typing import Iterator
+from unittest import skipIf
 
 import numpy as np
 import scipy
@@ -54,6 +56,20 @@ def _is_required_cuda_version_satisfied(cuda_version):
   return int(version.split()[-1]) >= cuda_version
 
 
+def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]:
+  """
+  Generate a range of valid axis arguments for a reduction over
+  an array with a given number of dimensions.
+  """
+  yield from (None, ())
+  if ndim > 0:
+    yield from (0, (-1,))
+  if ndim > 1:
+    yield from (1, (0, 1), (-1, 0))
+  if ndim > 2:
+    yield (-1, 0, 1)
+
+
 def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
   """scipy.linalg.toeplitz with v1.17+ batching semantics."""
   if scipy_version >= (1, 17, 0):
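
The new `_axis_for_ndim` helper enumerates the axis specifications the updated test sweeps over, mixing `None`, single ints, and tuples. For instance, assuming the helper as defined above:

    list(_axis_for_ndim(2))
    # -> [None, (), 0, (-1,), 1, (0, 1), (-1, 0)]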
@@ -707,29 +723,25 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
     self._CompileAndCheck(jnp_fn, args_maker)
 
+  @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
   @jtu.sample_product(
-      shape=[(3,), (3, 4), (2, 3, 4, 5)],
+      [
+          dict(shape=shape, axis=axis)
+          for shape in [(3,), (3, 4), (2, 3, 4, 5)]
+          for axis in _axis_for_ndim(len(shape))
+      ],
       dtype=float_types + complex_types,
       keepdims=[True, False],
-      axis=[0, None],
       ord=[1, -1, 2, -2, np.inf, -np.inf],
   )
   def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    if jtu.numpy_version() < (2, 0, 0):
-      def np_fn(x, *, ord, keepdims, axis):
-        x = np.asarray(x)
-        if axis is None:
-          result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0)
-          return np.reshape(result, (1,) * x.ndim) if keepdims else result
-        return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
-    else:
-      np_fn = np.linalg.vector_norm
-    np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis)
+    np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
     jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
-    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
-    self._CompileAndCheck(jnp_fn, args_maker)
+    tol = 1E-3 if jtu.test_device_matches(['tpu']) else None
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
 
   # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
   @jtu.sample_product(
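
The updated test drops the NumPy < 2.0 emulation of `vector_norm` (presumably because the `np.linalg.norm`-based fallback cannot reproduce vector-norm semantics over the new tuple-axis cases) and instead skips when the reference `np.linalg.vector_norm` is unavailable. A sketch of the version gate involved (illustrative only; `has_reference` is a made-up name):

    import numpy as np
    # np.linalg.vector_norm was introduced in NumPy 2.0; older releases have no direct reference function.
    has_reference = hasattr(np.linalg, "vector_norm")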