diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 30553a360..7f2a5785a 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -274,6 +274,7 @@ namespace; they are listed below. mask_indices matmul matrix_transpose + matvec max maximum mean @@ -428,6 +429,7 @@ namespace; they are listed below. var vdot vecdot + vecmat vectorize vsplit vstack diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3d9940542..3da4aa462 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9168,6 +9168,89 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) +@export +@jit +def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Batched matrix-vector product. + + JAX implementation of :func:`numpy.matvec`. + + Args: + x1: array of shape ``(..., M, N)`` + x2: array of shape ``(..., N)``. Leading dimensions must be broadcast-compatible + with leading dimensions of ``x1``. + + Returns: + An array of shape ``(..., M)`` containing the batched matrix-vector product. + + See also: + - :func:`jax.numpy.linalg.vecdot`: batched vector product. + - :func:`jax.numpy.vecmat`: vector-matrix product. + - :func:`jax.numpy.matmul`: general matrix multiplication. + + Examples: + Simple matrix-vector product: + + >>> x1 = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> x2 = jnp.array([7, 8, 9]) + >>> jnp.matvec(x1, x2) + Array([ 50, 122], dtype=int32) + + Batched matrix-vector product: + + >>> x2 = jnp.array([[7, 8, 9], + ... [5, 6, 7]]) + >>> jnp.matvec(x1, x2) + Array([[ 50, 122], + [ 38, 92]], dtype=int32) + """ + util.check_arraylike("matvec", x1, x2) + return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2) + + +@export +@jit +def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Batched conjugate vector-matrix product. + + JAX implementation of :func:`numpy.vecmat`. + + Args: + x1: array of shape ``(..., M)``. + x2: array of shape ``(..., M, N)``. Leading dimensions must be broadcast-compatible + with leading dimensions of ``x1``. + + Returns: + An array of shape ``(..., N)`` containing the batched conjugate vector-matrix product. + + See also: + - :func:`jax.numpy.linalg.vecdot`: batched vector product. + - :func:`jax.numpy.matvec`: matrix-vector product. + - :func:`jax.numpy.matmul`: general matrix multiplication. + + Examples: + Simple vector-matrix product: + + >>> x1 = jnp.array([[1, 2, 3]]) + >>> x2 = jnp.array([[4, 5], + ... [6, 7], + ... [8, 9]]) + >>> jnp.vecmat(x1, x2) + Array([[40, 46]], dtype=int32) + + Batched vector-matrix product: + + >>> x1 = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.vecmat(x1, x2) + Array([[ 40, 46], + [ 94, 109]], dtype=int32) + """ + util.check_arraylike("matvec", x1, x2) + return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2) + + @export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( @@ -9244,6 +9327,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, See Also: - :func:`jax.numpy.vdot`: flattened vector product. + - :func:`jax.numpy.vecmat`: vector-matrix product. - :func:`jax.numpy.matmul`: general matrix multiplication. - :func:`jax.lax.dot_general`: general N-dimensional batched dot product. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 12736c1cd..d0e06e68d 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -176,6 +176,7 @@ from jax._src.numpy.lax_numpy import ( logspace as logspace, mask_indices as mask_indices, matmul as matmul, + matvec as matvec, matrix_transpose as matrix_transpose, meshgrid as meshgrid, moveaxis as moveaxis, @@ -258,6 +259,7 @@ from jax._src.numpy.lax_numpy import ( vander as vander, vdot as vdot, vecdot as vecdot, + vecmat as vecmat, vsplit as vsplit, vstack as vstack, where as where, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 5d357ab1b..26874615a 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -645,6 +645,7 @@ def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ...) -> Array: ... def matrix_transpose(x: ArrayLike, /) -> Array: ... +def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ... def max(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... @@ -995,6 +996,7 @@ def vdot( def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ...) -> Array: ... +def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: ... def vsplit( ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 24c9dd0ca..7d26b1df8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -649,6 +649,57 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + @jtu.sample_product( + lhs_batch=broadcast_compatible_shapes, + rhs_batch=broadcast_compatible_shapes, + mat_size=[1, 2, 3], + vec_size=[2, 3, 4], + dtype=number_dtypes, + ) + @jax.default_matmul_precision("float32") + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testMatvec(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype): + rng = jtu.rand_default(self.rng()) + lhs_shape = (*lhs_batch, mat_size, vec_size) + rhs_shape = (*rhs_batch, vec_size) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + jnp_fn = jnp.matvec + @jtu.promote_like_jnp + def np_fn(x, y): + f = (np.vectorize(np.matmul, signature="(m,n),(n)->(m)") + if jtu.numpy_version() < (2, 2, 0) else np.matvec) + return f(x, y).astype(x.dtype) + tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, + np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1} + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + + @jtu.sample_product( + lhs_batch=broadcast_compatible_shapes, + rhs_batch=broadcast_compatible_shapes, + mat_size=[1, 2, 3], + vec_size=[2, 3, 4], + dtype=number_dtypes, + ) + @jax.default_matmul_precision("float32") + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testVecmat(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype): + rng = jtu.rand_default(self.rng()) + lhs_shape = (*lhs_batch, vec_size) + rhs_shape = (*rhs_batch, vec_size, mat_size) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + jnp_fn = jnp.vecmat + @jtu.promote_like_jnp + def np_fn(x, y): + f = (np.vectorize(lambda x, y: np.matmul(np.conj(x), y), + signature="(m),(m,n)->(n)") + if jtu.numpy_version() < (2, 2, 0) else np.vecmat) + return f(x, y).astype(x.dtype) + tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, + np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1} + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) for lhs_shape, rhs_shape, axes in [ @@ -6257,7 +6308,6 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'isnat', 'loadtxt', 'matrix', - 'matvec', 'may_share_memory', 'memmap', 'min_scalar_type', @@ -6283,8 +6333,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'show_runtime', 'test', 'trapz', - 'typename', - 'vecmat'} + 'typename'} # symbols removed in NumPy 2.0 skip |= {'add_docstring',