mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix jnp.matmul return shape documentation
If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
This commit is contained in:
parent
47858c4ac2
commit
cd578d97e8
@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
|
||||
|
||||
Returns:
|
||||
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
|
||||
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
|
||||
if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading
|
||||
dimensions of ``a`` and ``b`` are broadcast together.
|
||||
|
||||
See Also:
|
||||
|
Loading…
x
Reference in New Issue
Block a user