Fix jnp.matmul return shape documentation

If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
This commit is contained in:
Tor Gunnar Høst Houeland 2024-11-30 18:55:00 +00:00 committed by GitHub
parent 47858c4ac2
commit cd578d97e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
Returns:
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading
dimensions of ``a`` and ``b`` are broadcast together.
See Also: