From cd578d97e8367dce90c96705c386df5aaa299988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tor=20Gunnar=20H=C3=B8st=20Houeland?= <887395+houeland@users.noreply.github.com> Date: Sat, 30 Nov 2024 18:55:00 +0000 Subject: [PATCH] Fix jnp.matmul return shape documentation If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13) --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5f380fad9..a61b1d67f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, Returns: array containing the matrix product of the inputs. Shape is ``a.shape[:-1]`` - if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading + if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading dimensions of ``a`` and ``b`` are broadcast together. See Also: