Merge pull request #19252 from jakevdp:fix-vecdot

PiperOrigin-RevId: 596762791
jax authors 2024-01-08 18:54:07 -08:00
commit 6a99e38a82
2 changed files with 27 additions and 15 deletions


@@ -744,15 +744,11 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
   """Computes the (vector) dot product of two arrays."""
   check_arraylike("jnp.linalg.vecdot", x1, x2)
   x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
-  rank = max(x1_arr.ndim, x2_arr.ndim)
-  x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
-  x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
   if x1_arr.shape[axis] != x2_arr.shape[axis]:
-    raise ValueError("x1 and x2 must have the same size along specified axis.")
+    raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
   x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
   x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
-  # TODO(jakevdp): call lax.dot_general directly
-  return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
+  return jax.numpy.vectorize(jnp.vdot, signature="(n),(n)->()")(x1_arr, x2_arr)
 
 
 @_wraps(getattr(np.linalg, "matmul", None))
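
Note: the behavioral difference is easiest to see on complex inputs. The sketch below is not part of the diff; `old_vecdot` and `new_vecdot` are illustrative names that reproduce the removed and added return lines. jnp.vdot conjugates its first argument, matching np.linalg.vecdot semantics, while the matmul-based reduction does not.

import jax.numpy as jnp

def old_vecdot(x1, x2, axis=-1):
  # matmul-based reduction, as in the removed lines: no conjugation of x1.
  x1, x2 = jnp.moveaxis(x1, axis, -1), jnp.moveaxis(x2, axis, -1)
  return jnp.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]

def new_vecdot(x1, x2, axis=-1):
  # vectorized jnp.vdot, as in the added line: conjugates x1 and
  # broadcasts any leading (batch) dimensions automatically.
  x1, x2 = jnp.moveaxis(x1, axis, -1), jnp.moveaxis(x2, axis, -1)
  return jnp.vectorize(jnp.vdot, signature="(n),(n)->()")(x1, x2)

x = jnp.array([1 + 1j, 2 - 1j])
y = jnp.array([3 + 0j, 1 + 2j])
print(old_vecdot(x, y))  # sum(x * y)       = (7+6j)
print(new_vecdot(x, y))  # sum(conj(x) * y) = (3+2j)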


@@ -43,7 +43,7 @@ scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
 T = lambda x: np.swapaxes(x, -1, -2)
+broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
 float_types = jtu.dtypes.floating
 complex_types = jtu.dtypes.complex
 int_types = jtu.dtypes.all_integer
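
For reference, these batch shapes are pairwise broadcast-compatible, which is what lets the updated testVecDot below pair any lhs_batch with any rhs_batch. A quick illustrative check (not part of the diff):

import itertools
import numpy as np

broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
for a, b in itertools.combinations(broadcast_compatible_shapes, 2):
  # np.broadcast_shapes raises if a pair were incompatible.
  print(a, b, "->", np.broadcast_shapes(a, b))
# e.g. (4, 1) and (1, 3) -> (4, 3); (3,) and (4, 1) -> (4, 3)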
@@ -654,26 +654,42 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(jnp_fn, args_maker)
 
   @jtu.sample_product(
-      [dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
-       dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
-       dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
+      lhs_batch=broadcast_compatible_shapes,
+      rhs_batch=broadcast_compatible_shapes,
+      axis_size=[2, 4],
+      axis=range(-2, 2),
       dtype=float_types + complex_types,
   )
-  def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
+  @jax.default_matmul_precision("float32")
+  def testVecDot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
+    # Construct vecdot-compatible shapes.
+    size = min(len(lhs_batch), len(rhs_batch))
+    axis = int(np.clip(axis, -size - 1, size))
+    if axis >= 0:
+      lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
+      rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
+    else:
+      laxis = axis + len(lhs_batch) + 1
+      lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
+      raxis = axis + len(rhs_batch) + 1
+      rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
     if jtu.numpy_version() < (2, 0, 0):
       def np_fn(x, y, axis=axis):
-        x, y = np.broadcast_arrays(x, y)
         x = np.moveaxis(x, axis, -1)
         y = np.moveaxis(y, axis, -1)
-        return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
+        x, y = np.broadcast_arrays(x, y)
+        return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
     else:
       np_fn = partial(np.linalg.vecdot, axis=axis)
+    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
     jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
-    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
-    self._CompileAndCheck(jnp_fn, args_maker)
+    tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
+           np.complex64: 1E-3, np.complex128: 1e-12}
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
 
   # jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
   @jtu.sample_product(
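
As a worked example of the shape construction in testVecDot above (the concrete values are chosen for illustration, not taken from the test parameters): with lhs_batch=(4, 1), rhs_batch=(3,), axis_size=2, and axis=-2, the test inserts a length-2 contraction axis into both batch shapes so the operands share it while their remaining dimensions still broadcast.

import numpy as np

lhs_batch, rhs_batch, axis_size, axis = (4, 1), (3,), 2, -2  # illustrative values
size = min(len(lhs_batch), len(rhs_batch))                   # 1
axis = int(np.clip(axis, -size - 1, size))                   # stays -2
laxis = axis + len(lhs_batch) + 1                            # 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])  # (4, 2, 1)
raxis = axis + len(rhs_batch) + 1                            # 0
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])  # (2, 3)
# Along axis=-2 both operands have size 2; the remaining dims (4, 1) and (3,)
# broadcast to (4, 3), so jnp.linalg.vecdot(x1, x2, axis=-2) yields shape (4, 3).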