mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better docs for jnp.cross
This commit is contained in:
parent
ca2d1584f8
commit
a1140e9246
@ -9511,10 +9511,82 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
|
||||
a, b = util.promote_dtypes(a, b)
|
||||
return ravel(a)[:, None] * ravel(b)[None, :]
|
||||
|
||||
@util.implements(np.cross)
|
||||
|
||||
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
|
||||
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
|
||||
axis: int | None = None):
|
||||
r"""Compute the (batched) cross product of two arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.cross`.
|
||||
|
||||
This computes the 2-dimensional or 3-dimensional cross product,
|
||||
|
||||
.. math::
|
||||
|
||||
c = a \times b
|
||||
|
||||
In 3 dimensions, ``c`` is a length-3 array. In 2 dimensions, ``c`` is
|
||||
a scalar.
|
||||
|
||||
Args:
|
||||
a: N-dimensional array. ``a.shape[axisa]`` indicates the dimension of
|
||||
the cross product, and must be 2 or 3.
|
||||
b: N-dimensional array. Must have ``b.shape[axisb] == a.shape[axisb]``,
|
||||
and other dimensions of ``a`` and ``b`` must be broadcast compatible.
|
||||
axisa: specicy the axis of ``a`` along which to compute the cross product.
|
||||
axisb: specicy the axis of ``b`` along which to compute the cross product.
|
||||
axisc: specicy the axis of ``c`` along which the cross product result
|
||||
will be stored.
|
||||
axis: if specified, this overrides ``axisa``, ``axisb``, and ``axisc``
|
||||
with a single value.
|
||||
|
||||
Returns:
|
||||
The array ``c`` containing the (batched) cross product of ``a`` and ``b``
|
||||
along the specified axes.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.linalg.cross`: an array API compatible function for
|
||||
computing cross products over 3-vectors.
|
||||
|
||||
Examples:
|
||||
A 2-dimensional cross product returns a scalar:
|
||||
|
||||
>>> a = jnp.array([1, 2])
|
||||
>>> b = jnp.array([3, 4])
|
||||
>>> jnp.cross(a, b)
|
||||
Array(-2, dtype=int32)
|
||||
|
||||
A 3-dimensional cross product returns a length-3 vector:
|
||||
|
||||
>>> a = jnp.array([1, 2, 3])
|
||||
>>> b = jnp.array([4, 5, 6])
|
||||
>>> jnp.cross(a, b)
|
||||
Array([-3, 6, -3], dtype=int32)
|
||||
|
||||
With multi-dimensional inputs, the cross-product is computed along
|
||||
the last axis by default. Here's a batched 3-dimensional cross
|
||||
product, operating on the rows of the inputs:
|
||||
|
||||
>>> a = jnp.array([[1, 2, 3],
|
||||
... [3, 4, 3]])
|
||||
>>> b = jnp.array([[2, 3, 2],
|
||||
... [4, 5, 6]])
|
||||
>>> jnp.cross(a, b)
|
||||
Array([[-5, 4, -1],
|
||||
[ 9, -6, -1]], dtype=int32)
|
||||
|
||||
Specifying axis=0 makes this a batched 2-dimensional cross product,
|
||||
operating on the columns of the inputs:
|
||||
|
||||
>>> jnp.cross(a, b, axis=0)
|
||||
Array([-2, -2, 12], dtype=int32)
|
||||
|
||||
Equivalently, we can independently specify the axis of the inputs ``a``
|
||||
and ``b`` and the output ``c``:
|
||||
|
||||
>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
|
||||
Array([-2, -2, 12], dtype=int32)
|
||||
"""
|
||||
# TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
|
||||
util.check_arraylike("cross", a, b)
|
||||
if axis is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user