mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
This commit is contained in:
commit
6a99e38a82
@ -744,15 +744,11 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
|
||||
"""Computes the (vector) dot product of two arrays."""
|
||||
check_arraylike("jnp.linalg.vecdot", x1, x2)
|
||||
x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
|
||||
rank = max(x1_arr.ndim, x2_arr.ndim)
|
||||
x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
|
||||
x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
|
||||
if x1_arr.shape[axis] != x2_arr.shape[axis]:
|
||||
raise ValueError("x1 and x2 must have the same size along specified axis.")
|
||||
raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
|
||||
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
|
||||
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
|
||||
# TODO(jakevdp): call lax.dot_general directly
|
||||
return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
|
||||
return jax.numpy.vectorize(jnp.vdot, signature="(n),(n)->()")(x1_arr, x2_arr)
|
||||
|
||||
|
||||
@_wraps(getattr(np.linalg, "matmul", None))
|
||||
|
@ -43,7 +43,7 @@ scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
|
||||
|
||||
T = lambda x: np.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
int_types = jtu.dtypes.all_integer
|
||||
@ -654,26 +654,42 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
|
||||
dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
|
||||
dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
|
||||
lhs_batch=broadcast_compatible_shapes,
|
||||
rhs_batch=broadcast_compatible_shapes,
|
||||
axis_size=[2, 4],
|
||||
axis=range(-2, 2),
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testVecDot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
|
||||
# Construct vecdot-compatible shapes.
|
||||
size = min(len(lhs_batch), len(rhs_batch))
|
||||
axis = int(np.clip(axis, -size - 1, size))
|
||||
if axis >= 0:
|
||||
lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
|
||||
rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
|
||||
else:
|
||||
laxis = axis + len(lhs_batch) + 1
|
||||
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
|
||||
raxis = axis + len(rhs_batch) + 1
|
||||
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
if jtu.numpy_version() < (2, 0, 0):
|
||||
def np_fn(x, y, axis=axis):
|
||||
x, y = np.broadcast_arrays(x, y)
|
||||
x = np.moveaxis(x, axis, -1)
|
||||
y = np.moveaxis(y, axis, -1)
|
||||
return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
|
||||
x, y = np.broadcast_arrays(x, y)
|
||||
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
|
||||
else:
|
||||
np_fn = partial(np.linalg.vecdot, axis=axis)
|
||||
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
|
||||
jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
|
||||
np.complex64: 1E-3, np.complex128: 1e-12}
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
# jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
|
||||
@jtu.sample_product(
|
||||
|
Loading…
x
Reference in New Issue
Block a user