Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
jax.numpy: implement matvec & vecmat
This commit is contained in:
parent 2ff90382d2
commit f6d58761d1
@@ -274,6 +274,7 @@ namespace; they are listed below.
     mask_indices
     matmul
     matrix_transpose
+    matvec
     max
     maximum
     mean
@@ -428,6 +429,7 @@ namespace; they are listed below.
     var
     vdot
     vecdot
+    vecmat
     vectorize
     vsplit
     vstack
@@ -9168,6 +9168,89 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
   return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
 
 
+@export
+@jit
+def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Batched matrix-vector product.
+
+  JAX implementation of :func:`numpy.matvec`.
+
+  Args:
+    x1: array of shape ``(..., M, N)``.
+    x2: array of shape ``(..., N)``. Leading dimensions must be broadcast-compatible
+      with leading dimensions of ``x1``.
+
+  Returns:
+    An array of shape ``(..., M)`` containing the batched matrix-vector product.
+
+  See also:
+    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
+    - :func:`jax.numpy.vecmat`: vector-matrix product.
+    - :func:`jax.numpy.matmul`: general matrix multiplication.
+
+  Examples:
+    Simple matrix-vector product:
+
+    >>> x1 = jnp.array([[1, 2, 3],
+    ...                 [4, 5, 6]])
+    >>> x2 = jnp.array([7, 8, 9])
+    >>> jnp.matvec(x1, x2)
+    Array([ 50, 122], dtype=int32)
+
+    Batched matrix-vector product:
+
+    >>> x2 = jnp.array([[7, 8, 9],
+    ...                 [5, 6, 7]])
+    >>> jnp.matvec(x1, x2)
+    Array([[ 50, 122],
+           [ 38,  92]], dtype=int32)
+  """
+  util.check_arraylike("matvec", x1, x2)
+  return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2)
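
Aside (not part of the diff): the ``vectorize(matmul, signature="(n,m),(m)->(n)")`` construction contracts the trailing axis of ``x1`` against ``x2`` and broadcasts any leading batch dimensions, so ``matvec`` should agree with an explicit ``einsum`` contraction. A minimal doctest-style sketch, assuming a JAX version that ships ``jnp.matvec``:

>>> import jax.numpy as jnp
>>> x1 = jnp.arange(24.0).reshape(4, 2, 3)  # batch of four 2x3 matrices
>>> x2 = jnp.arange(12.0).reshape(4, 3)     # batch of four 3-vectors
>>> jnp.matvec(x1, x2).shape
(4, 2)
>>> bool(jnp.allclose(jnp.matvec(x1, x2), jnp.einsum('...mn,...n->...m', x1, x2)))
True
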
+
+
+@export
+@jit
+def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Batched conjugate vector-matrix product.
+
+  JAX implementation of :func:`numpy.vecmat`.
+
+  Args:
+    x1: array of shape ``(..., M)``.
+    x2: array of shape ``(..., M, N)``. Leading dimensions must be broadcast-compatible
+      with leading dimensions of ``x1``.
+
+  Returns:
+    An array of shape ``(..., N)`` containing the batched conjugate vector-matrix product.
+
+  See also:
+    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
+    - :func:`jax.numpy.matvec`: matrix-vector product.
+    - :func:`jax.numpy.matmul`: general matrix multiplication.
+
+  Examples:
+    Simple vector-matrix product:
+
+    >>> x1 = jnp.array([[1, 2, 3]])
+    >>> x2 = jnp.array([[4, 5],
+    ...                 [6, 7],
+    ...                 [8, 9]])
+    >>> jnp.vecmat(x1, x2)
+    Array([[40, 46]], dtype=int32)
+
+    Batched vector-matrix product:
+
+    >>> x1 = jnp.array([[1, 2, 3],
+    ...                 [4, 5, 6]])
+    >>> jnp.vecmat(x1, x2)
+    Array([[ 40,  46],
+           [ 94, 109]], dtype=int32)
+  """
+  util.check_arraylike("vecmat", x1, x2)
+  return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2)
+
+
 @export
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def vdot(
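
Aside (not part of the diff): ``vecmat`` conjugates its vector operand before the product, matching the :func:`numpy.vecmat` convention it implements; for complex inputs it therefore equals an ``einsum`` over the conjugated vector. A minimal sketch, again assuming ``jnp.vecmat`` is available:

>>> import jax.numpy as jnp
>>> x1 = jnp.array([1 + 2j, 3 - 1j, 0.5j])
>>> x2 = jnp.arange(6.0).reshape(3, 2).astype(jnp.complex64)
>>> jnp.vecmat(x1, x2).shape
(2,)
>>> bool(jnp.allclose(jnp.vecmat(x1, x2), jnp.einsum('m,mn->n', jnp.conj(x1), x2)))
True
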
@@ -9244,6 +9327,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
 
   See Also:
     - :func:`jax.numpy.vdot`: flattened vector product.
+    - :func:`jax.numpy.vecmat`: vector-matrix product.
     - :func:`jax.numpy.matmul`: general matrix multiplication.
     - :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
 
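
Aside (not part of the diff): the cross-referenced functions differ only in which operand carries the matrix axes. A quick shape-level comparison, assuming ``jnp`` is ``jax.numpy``:

>>> import jax.numpy as jnp
>>> m = jnp.ones((2, 3))
>>> jnp.matvec(m, jnp.ones(3)).shape            # (M, N) @ (N,)       -> (M,)
(2,)
>>> jnp.vecmat(jnp.ones(2), m).shape            # conj((M,)) @ (M, N) -> (N,)
(3,)
>>> jnp.vecdot(jnp.ones(3), jnp.ones(3)).shape  # (N,) . (N,)         -> ()
()
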
@@ -176,6 +176,7 @@ from jax._src.numpy.lax_numpy import (
     logspace as logspace,
     mask_indices as mask_indices,
     matmul as matmul,
+    matvec as matvec,
     matrix_transpose as matrix_transpose,
     meshgrid as meshgrid,
     moveaxis as moveaxis,
@@ -258,6 +259,7 @@ from jax._src.numpy.lax_numpy import (
     vander as vander,
     vdot as vdot,
     vecdot as vecdot,
+    vecmat as vecmat,
     vsplit as vsplit,
     vstack as vstack,
     where as where,
@@ -645,6 +645,7 @@ def matmul(
     a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
     preferred_element_type: DTypeLike | None = ...) -> Array: ...
 def matrix_transpose(x: ArrayLike, /) -> Array: ...
+def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ...
 def max(a: ArrayLike, axis: _Axis = ..., out: None = ...,
         keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
         where: ArrayLike | None = ...) -> Array: ...
@@ -995,6 +996,7 @@ def vdot(
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ...,
            precision: PrecisionLike = ...,
            preferred_element_type: DTypeLike | None = ...) -> Array: ...
+def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: ...
 def vsplit(
     ary: ArrayLike, indices_or_sections: int | ArrayLike
 ) -> list[Array]: ...
@@ -649,6 +649,57 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
 
+  @jtu.sample_product(
+      lhs_batch=broadcast_compatible_shapes,
+      rhs_batch=broadcast_compatible_shapes,
+      mat_size=[1, 2, 3],
+      vec_size=[2, 3, 4],
+      dtype=number_dtypes,
+  )
+  @jax.default_matmul_precision("float32")
+  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
+  def testMatvec(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
+    rng = jtu.rand_default(self.rng())
+    lhs_shape = (*lhs_batch, mat_size, vec_size)
+    rhs_shape = (*rhs_batch, vec_size)
+    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
+    jnp_fn = jnp.matvec
+    @jtu.promote_like_jnp
+    def np_fn(x, y):
+      f = (np.vectorize(np.matmul, signature="(m,n),(n)->(m)")
+           if jtu.numpy_version() < (2, 2, 0) else np.matvec)
+      return f(x, y).astype(x.dtype)
+    tol = {np.float16: 1e-2, np.float32: 1e-3, np.float64: 1e-12,
+           np.complex64: 1e-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
+
+  @jtu.sample_product(
+      lhs_batch=broadcast_compatible_shapes,
+      rhs_batch=broadcast_compatible_shapes,
+      mat_size=[1, 2, 3],
+      vec_size=[2, 3, 4],
+      dtype=number_dtypes,
+  )
+  @jax.default_matmul_precision("float32")
+  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
+  def testVecmat(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
+    rng = jtu.rand_default(self.rng())
+    lhs_shape = (*lhs_batch, vec_size)
+    rhs_shape = (*rhs_batch, vec_size, mat_size)
+    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
+    jnp_fn = jnp.vecmat
+    @jtu.promote_like_jnp
+    def np_fn(x, y):
+      f = (np.vectorize(lambda x, y: np.matmul(np.conj(x), y),
+                        signature="(m),(m,n)->(n)")
+           if jtu.numpy_version() < (2, 2, 0) else np.vecmat)
+      return f(x, y).astype(x.dtype)
+    tol = {np.float16: 1e-2, np.float32: 1e-3, np.float64: 1e-12,
+           np.complex64: 1e-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
+
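
Aside (not part of the diff): on NumPy versions older than 2.2, which lack ``np.matvec`` and ``np.vecmat``, the tests above synthesize reference implementations from ``np.vectorize`` plus a gufunc signature. The same fallback can be written standalone; ``ref_matvec`` and ``ref_vecmat`` are hypothetical names used only for illustration:

import numpy as np

# matvec fallback: broadcasted matrix-vector product over the trailing axes.
ref_matvec = np.vectorize(np.matmul, signature="(m,n),(n)->(m)")

# vecmat fallback: conjugate the vector, then contract it with the matrix.
ref_vecmat = np.vectorize(lambda v, m: np.matmul(np.conj(v), m),
                          signature="(m),(m,n)->(n)")

a = np.arange(6.0).reshape(2, 3)  # single 2x3 matrix
v = np.array([1.0, 2.0, 3.0])     # single 3-vector
print(ref_matvec(a, v))    # [ 8. 26.]
print(ref_vecmat(v, a.T))  # [ 8. 26.]  (a.T has shape (3, 2))
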
   @jtu.sample_product(
     [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
      for lhs_shape, rhs_shape, axes in [
@@ -6257,7 +6308,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
       'isnat',
       'loadtxt',
       'matrix',
-      'matvec',
       'may_share_memory',
       'memmap',
       'min_scalar_type',
@@ -6283,8 +6333,7 @@
       'show_runtime',
       'test',
       'trapz',
-      'typename',
-      'vecmat'}
+      'typename'}
 
     # symbols removed in NumPy 2.0
     skip |= {'add_docstring',