Merge pull request #21106 from jakevdp:linalg-precision
PiperOrigin-RevId: 632217396
commit 1a7a2aa555
@@ -35,7 +35,7 @@ from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import reductions, ufuncs
 from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
 from jax._src.util import canonicalize_axis
-from jax._src.typing import ArrayLike, Array
+from jax._src.typing import ArrayLike, Array, DTypeLike


 class EighResult(NamedTuple):
@@ -1612,7 +1612,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
   return norm(x, axis=axis, keepdims=keepdims, ord=ord)


-def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
+def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
+           precision: PrecisionLike = None,
+           preferred_element_type: DTypeLike | None = None) -> Array:
   """Compute the (batched) vector conjugate dot product of two arrays.

   JAX implementation of :func:`numpy.linalg.vecdot`.
@@ -1622,6 +1624,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
     x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
       and remaining dimensions must be broadcast-compatible.
     axis: axis along which to compute the dot product (default: -1)
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
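For illustration, a minimal sketch of the extended ``vecdot`` signature documented in the hunk above. This assumes a JAX build that includes this change; the input arrays are invented for the example.

import jax.numpy as jnp
from jax import lax

# Invented inputs: a batch of two 3-vectors and a single 3-vector.
x1 = jnp.array([[1., 2., 3.],
                [4., 5., 6.]])
x2 = jnp.array([1., 2., 3.])

# Pre-existing behavior: conjugate dot product along the last axis.
out = jnp.linalg.vecdot(x1, x2, axis=-1)

# New keywords: request the backend's highest precision and float32
# accumulation/output.
out_hi = jnp.linalg.vecdot(x1, x2, axis=-1,
                           precision=lax.Precision.HIGHEST,
                           preferred_element_type=jnp.float32)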
@@ -1649,10 +1658,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
     Array([20, 47], dtype=int32)
   """
   check_arraylike('jnp.linalg.vecdot', x1, x2)
-  return jnp.vecdot(x1, x2, axis=axis)
+  return jnp.vecdot(x1, x2, axis=axis, precision=precision,
+                    preferred_element_type=preferred_element_type)


-def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
+           precision: PrecisionLike = None,
+           preferred_element_type: DTypeLike | None = None) -> Array:
   """Perform a matrix multiplication.

   JAX implementation of :func:`numpy.linalg.matmul`.
@@ -1662,6 +1674,13 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
     x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
       In the multi-dimensional case, leading dimensions must be broadcast-compatible
      with the leading dimensions of ``x1``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]``
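A corresponding sketch for ``matmul``, again assuming a JAX build with this change and invented inputs. A common use of ``preferred_element_type`` is accumulating a low-precision product in a wider type:

import jax.numpy as jnp
from jax import lax

# Invented inputs: low-precision matrices.
a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 4), dtype=jnp.bfloat16)

# Accumulate and return in float32 rather than the default accumulation
# type for bfloat16 inputs.
c = jnp.linalg.matmul(a, b, preferred_element_type=jnp.float32)

# Precision may also be a two-tuple, one entry per operand.
c2 = jnp.linalg.matmul(a, b,
                       precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))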
@@ -1699,11 +1718,14 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
            [49, 64]], dtype=int32)
   """
   check_arraylike('jnp.linalg.matmul', x1, x2)
-  return jnp.matmul(x1, x2)
+  return jnp.matmul(x1, x2, precision=precision,
+                    preferred_element_type=preferred_element_type)


 def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
-              axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
+              axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+              precision: PrecisionLike = None,
+              preferred_element_type: DTypeLike | None = None) -> Array:
   """Compute the tensor dot product of two N-dimensional arrays.

   JAX implementation of :func:`numpy.linalg.tensordot`.
@@ -1715,6 +1737,13 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
       sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``,
       in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and
       ``axes[1]`` specifies the axes of ``x2``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the tensor dot product of the inputs
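A sketch of both ``axes`` forms of ``tensordot`` with the new keywords; shapes and values here are invented, and this assumes a JAX build with this change:

import jax.numpy as jnp
from jax import lax

# Invented inputs.
x = jnp.ones((2, 3, 4))
y = jnp.ones((3, 4, 5))

# Integer axes: contract the last two axes of x with the first two of y.
z = jnp.linalg.tensordot(x, y, axes=2,
                         precision=lax.Precision.HIGHEST)  # shape (2, 5)

# Tuple axes: axes[0] picks axes of x, axes[1] picks axes of y.
z2 = jnp.linalg.tensordot(x, y, axes=((1, 2), (0, 1)),
                          precision=lax.Precision.HIGHEST,
                          preferred_element_type=jnp.float32)  # shape (2, 5)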
@@ -1770,7 +1799,8 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
           [2, 4, 6]], dtype=int32)
   """
   check_arraylike('jnp.linalg.tensordot', x1, x2)
-  return jnp.tensordot(x1, x2, axes=axes)
+  return jnp.tensordot(x1, x2, axes=axes, precision=precision,
+                       preferred_element_type=preferred_element_type)


 def svdvals(x: ArrayLike, /) -> Array:
@@ -697,6 +697,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.vecdot, axis=axis,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   # jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
   @jtu.sample_product(
     [
@@ -719,6 +725,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.matmul,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   # jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here.
   @jtu.sample_product(
     [
@@ -742,6 +754,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.tensordot, axes=axes,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   @jtu.sample_product(
     [
       dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
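Outside the jtu harness, the smoke-test pattern above reduces to binding the JAX-only keywords with functools.partial so the bound function and a NumPy reference take identical positional arguments. A standalone sketch with invented inputs and an illustrative tolerance:

from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import lax

# Bind the keywords the NumPy reference does not accept.
jnp_fn = partial(jnp.linalg.matmul,
                 precision=lax.Precision.HIGHEST,
                 preferred_element_type=jnp.float32)

a = np.arange(6., dtype=np.float32).reshape(2, 3)
b = np.arange(12., dtype=np.float32).reshape(3, 4)
np.testing.assert_allclose(np.asarray(jnp_fn(a, b)),
                           np.matmul(a, b), rtol=1e-6)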