mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Call _check_arraylike for jnp.linalg & jnp.fft functions
This commit is contained in:
parent
32a0ea80ef
commit
2416d15435
@ -21,6 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Breaking Changes
|
* Breaking Changes
|
||||||
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
|
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
|
||||||
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
|
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
|
||||||
|
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
|
||||||
|
require inputs to be array-like: i.e. lists and tuples cannot be used in place
|
||||||
|
of arrays. Part of {jax-issue}`#7737`.
|
||||||
|
|
||||||
## jaxlib 0.3.24
|
## jaxlib 0.3.24
|
||||||
* Changes
|
* Changes
|
||||||
|
@ -21,7 +21,7 @@ from jax import dtypes
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src.lib import xla_client
|
from jax._src.lib import xla_client
|
||||||
from jax._src.util import safe_zip
|
from jax._src.util import safe_zip
|
||||||
from jax._src.numpy.util import _wraps
|
from jax._src.numpy.util import _check_arraylike, _wraps
|
||||||
from jax._src.numpy import lax_numpy as jnp
|
from jax._src.numpy import lax_numpy as jnp
|
||||||
from jax._src.typing import Array, ArrayLike
|
from jax._src.typing import Array, ArrayLike
|
||||||
|
|
||||||
@ -42,8 +42,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
|
|||||||
s: Optional[Shape], axes: Optional[Sequence[int]],
|
s: Optional[Shape], axes: Optional[Sequence[int]],
|
||||||
norm: Optional[str]) -> Array:
|
norm: Optional[str]) -> Array:
|
||||||
full_name = "jax.numpy.fft." + func_name
|
full_name = "jax.numpy.fft." + func_name
|
||||||
|
_check_arraylike(full_name, a)
|
||||||
# TODO(jakevdp): call check_arraylike
|
|
||||||
arr = jnp.asarray(a)
|
arr = jnp.asarray(a)
|
||||||
|
|
||||||
if s is not None:
|
if s is not None:
|
||||||
@ -285,6 +284,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
|
|||||||
|
|
||||||
@_wraps(np.fft.fftshift)
|
@_wraps(np.fft.fftshift)
|
||||||
def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
|
def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
|
||||||
|
_check_arraylike("fftshift", x)
|
||||||
x = jnp.asarray(x)
|
x = jnp.asarray(x)
|
||||||
shift: Union[int, Sequence[int]]
|
shift: Union[int, Sequence[int]]
|
||||||
if axes is None:
|
if axes is None:
|
||||||
@ -300,6 +300,7 @@ def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Arra
|
|||||||
|
|
||||||
@_wraps(np.fft.ifftshift)
|
@_wraps(np.fft.ifftshift)
|
||||||
def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
|
def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
|
||||||
|
_check_arraylike("ifftshift", x)
|
||||||
x = jnp.asarray(x)
|
x = jnp.asarray(x)
|
||||||
shift: Union[int, Sequence[int]]
|
shift: Union[int, Sequence[int]]
|
||||||
if axes is None:
|
if axes is None:
|
||||||
|
@ -28,7 +28,7 @@ from jax import lax
|
|||||||
from jax._src.lax import lax as lax_internal
|
from jax._src.lax import lax as lax_internal
|
||||||
from jax._src.lax import linalg as lax_linalg
|
from jax._src.lax import linalg as lax_linalg
|
||||||
from jax._src.numpy import lax_numpy as jnp
|
from jax._src.numpy import lax_numpy as jnp
|
||||||
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
|
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike
|
||||||
from jax._src.util import canonicalize_axis
|
from jax._src.util import canonicalize_axis
|
||||||
from jax._src.typing import ArrayLike, Array
|
from jax._src.typing import ArrayLike, Array
|
||||||
|
|
||||||
@ -44,6 +44,7 @@ def _H(x: ArrayLike) -> Array:
|
|||||||
@_wraps(np.linalg.cholesky)
|
@_wraps(np.linalg.cholesky)
|
||||||
@jit
|
@jit
|
||||||
def cholesky(a: ArrayLike) -> Array:
|
def cholesky(a: ArrayLike) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.cholesky", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
return lax_linalg.cholesky(a)
|
return lax_linalg.cholesky(a)
|
||||||
|
|
||||||
@ -67,6 +68,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
|||||||
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
|
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
|
||||||
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||||
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
|
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
|
||||||
|
_check_arraylike("jnp.linalg.svd", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
if hermitian:
|
if hermitian:
|
||||||
w, v = lax_linalg.eigh(a)
|
w, v = lax_linalg.eigh(a)
|
||||||
@ -90,7 +92,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
|||||||
@_wraps(np.linalg.matrix_power)
|
@_wraps(np.linalg.matrix_power)
|
||||||
@partial(jit, static_argnames=('n',))
|
@partial(jit, static_argnames=('n',))
|
||||||
def matrix_power(a: ArrayLike, n: int) -> Array:
|
def matrix_power(a: ArrayLike, n: int) -> Array:
|
||||||
# TODO(jakevdp): call _check_arraylike
|
_check_arraylike("jnp.linalg.matrix_power", a)
|
||||||
arr, = _promote_dtypes_inexact(jnp.asarray(a))
|
arr, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
|
|
||||||
if arr.ndim < 2:
|
if arr.ndim < 2:
|
||||||
@ -129,6 +131,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
|
|||||||
@_wraps(np.linalg.matrix_rank)
|
@_wraps(np.linalg.matrix_rank)
|
||||||
@jit
|
@jit
|
||||||
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
|
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.matrix_rank", M)
|
||||||
M, = _promote_dtypes_inexact(jnp.asarray(M))
|
M, = _promote_dtypes_inexact(jnp.asarray(M))
|
||||||
if M.ndim < 2:
|
if M.ndim < 2:
|
||||||
return jnp.any(M != 0).astype(jnp.int32)
|
return jnp.any(M != 0).astype(jnp.int32)
|
||||||
@ -191,6 +194,7 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
|
|||||||
"""))
|
"""))
|
||||||
@partial(jit, static_argnames=('method',))
|
@partial(jit, static_argnames=('method',))
|
||||||
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
|
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
|
||||||
|
_check_arraylike("jnp.linalg.slogdet", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
a_shape = jnp.shape(a)
|
a_shape = jnp.shape(a)
|
||||||
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
|
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
|
||||||
@ -329,6 +333,7 @@ def _det_3x3(a: Array) -> Array:
|
|||||||
@_wraps(np.linalg.det)
|
@_wraps(np.linalg.det)
|
||||||
@jit
|
@jit
|
||||||
def det(a: ArrayLike) -> Array:
|
def det(a: ArrayLike) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.det", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
a_shape = jnp.shape(a)
|
a_shape = jnp.shape(a)
|
||||||
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
|
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
|
||||||
@ -361,6 +366,7 @@ backend. However eigendecomposition for symmetric/Hermitian matrices is
|
|||||||
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
|
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
|
||||||
""")
|
""")
|
||||||
def eig(a: ArrayLike) -> Tuple[Array, Array]:
|
def eig(a: ArrayLike) -> Tuple[Array, Array]:
|
||||||
|
_check_arraylike("jnp.linalg.eig", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
|
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
|
||||||
return w, v
|
return w, v
|
||||||
@ -369,6 +375,7 @@ def eig(a: ArrayLike) -> Tuple[Array, Array]:
|
|||||||
@_wraps(np.linalg.eigvals)
|
@_wraps(np.linalg.eigvals)
|
||||||
@jit
|
@jit
|
||||||
def eigvals(a: ArrayLike) -> Array:
|
def eigvals(a: ArrayLike) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.eigvals", a)
|
||||||
return lax_linalg.eig(a, compute_left_eigenvectors=False,
|
return lax_linalg.eig(a, compute_left_eigenvectors=False,
|
||||||
compute_right_eigenvectors=False)[0]
|
compute_right_eigenvectors=False)[0]
|
||||||
|
|
||||||
@ -377,6 +384,7 @@ def eigvals(a: ArrayLike) -> Array:
|
|||||||
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
|
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
|
||||||
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
|
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
|
||||||
symmetrize_input: bool = True) -> Tuple[Array, Array]:
|
symmetrize_input: bool = True) -> Tuple[Array, Array]:
|
||||||
|
_check_arraylike("jnp.linalg.eigh", a)
|
||||||
if UPLO is None or UPLO == "L":
|
if UPLO is None or UPLO == "L":
|
||||||
lower = True
|
lower = True
|
||||||
elif UPLO == "U":
|
elif UPLO == "U":
|
||||||
@ -393,6 +401,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None,
|
|||||||
@_wraps(np.linalg.eigvalsh)
|
@_wraps(np.linalg.eigvalsh)
|
||||||
@partial(jit, static_argnames=('UPLO',))
|
@partial(jit, static_argnames=('UPLO',))
|
||||||
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.eigvalsh", a)
|
||||||
w, _ = eigh(a, UPLO)
|
w, _ = eigh(a, UPLO)
|
||||||
return w
|
return w
|
||||||
|
|
||||||
@ -407,6 +416,7 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
|||||||
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
||||||
# Uses same algorithm as
|
# Uses same algorithm as
|
||||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||||
|
_check_arraylike("jnp.linalg.pinv", a)
|
||||||
arr = jnp.conj(a)
|
arr = jnp.conj(a)
|
||||||
if rcond is None:
|
if rcond is None:
|
||||||
max_rows_cols = max(arr.shape[-2:])
|
max_rows_cols = max(arr.shape[-2:])
|
||||||
@ -447,7 +457,7 @@ def _pinv_jvp(rcond, primals, tangents):
|
|||||||
@_wraps(np.linalg.inv)
|
@_wraps(np.linalg.inv)
|
||||||
@jit
|
@jit
|
||||||
def inv(a: ArrayLike) -> Array:
|
def inv(a: ArrayLike) -> Array:
|
||||||
# TODO(jakevdp): call _check_arraylike
|
_check_arraylike("jnp.linalg.inv", a)
|
||||||
arr = jnp.asarray(a)
|
arr = jnp.asarray(a)
|
||||||
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
|
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -461,6 +471,7 @@ def inv(a: ArrayLike) -> Array:
|
|||||||
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
|
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
|
||||||
axis: Union[None, Tuple[int, ...], int] = None,
|
axis: Union[None, Tuple[int, ...], int] = None,
|
||||||
keepdims: bool = False) -> Array:
|
keepdims: bool = False) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.norm", x)
|
||||||
x, = _promote_dtypes_inexact(jnp.asarray(x))
|
x, = _promote_dtypes_inexact(jnp.asarray(x))
|
||||||
x_shape = jnp.shape(x)
|
x_shape = jnp.shape(x)
|
||||||
ndim = len(x_shape)
|
ndim = len(x_shape)
|
||||||
@ -560,6 +571,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
|
|||||||
@_wraps(np.linalg.qr)
|
@_wraps(np.linalg.qr)
|
||||||
@partial(jit, static_argnames=('mode',))
|
@partial(jit, static_argnames=('mode',))
|
||||||
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
|
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
|
||||||
|
_check_arraylike("jnp.linalg.qr", a)
|
||||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||||
if mode == "raw":
|
if mode == "raw":
|
||||||
a, taus = lax_linalg.geqrf(a)
|
a, taus = lax_linalg.geqrf(a)
|
||||||
@ -579,6 +591,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
|
|||||||
@_wraps(np.linalg.solve)
|
@_wraps(np.linalg.solve)
|
||||||
@jit
|
@jit
|
||||||
def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
||||||
|
_check_arraylike("jnp.linalg.solve", a, b)
|
||||||
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
|
||||||
return lax_linalg._solve(a, b)
|
return lax_linalg._solve(a, b)
|
||||||
|
|
||||||
@ -645,6 +658,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
|
|||||||
"""))
|
"""))
|
||||||
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
|
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
|
||||||
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
|
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
|
||||||
|
_check_arraylike("jnp.linalg.lstsq", a, b)
|
||||||
if numpy_resid:
|
if numpy_resid:
|
||||||
return _lstsq(a, b, rcond, numpy_resid=True)
|
return _lstsq(a, b, rcond, numpy_resid=True)
|
||||||
return _jit_lstsq(a, b, rcond)
|
return _jit_lstsq(a, b, rcond)
|
||||||
|
178
jax/_src/third_party/numpy/linalg.py
vendored
178
jax/_src/third_party/numpy/linalg.py
vendored
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
|
|
||||||
from jax._src.numpy import lax_numpy as jnp
|
from jax._src.numpy import lax_numpy as jnp
|
||||||
from jax._src.numpy import linalg as la
|
from jax._src.numpy import linalg as la
|
||||||
from jax._src.numpy.util import _wraps
|
from jax._src.numpy.util import _check_arraylike, _wraps
|
||||||
|
|
||||||
|
|
||||||
def _isEmpty2d(arr):
|
def _isEmpty2d(arr):
|
||||||
@ -33,14 +33,15 @@ def _assertNdSquareness(*arrays):
|
|||||||
|
|
||||||
|
|
||||||
def _assert2d(*arrays):
|
def _assert2d(*arrays):
|
||||||
for a in arrays:
|
for a in arrays:
|
||||||
if a.ndim != 2:
|
if a.ndim != 2:
|
||||||
raise ValueError(f'{a.ndim}-dimensional array given. '
|
raise ValueError(f'{a.ndim}-dimensional array given. '
|
||||||
'Array must be two-dimensional')
|
'Array must be two-dimensional')
|
||||||
|
|
||||||
|
|
||||||
@_wraps(np.linalg.cond)
|
@_wraps(np.linalg.cond)
|
||||||
def cond(x, p=None):
|
def cond(x, p=None):
|
||||||
|
_check_arraylike('jnp.linalg.cond', x)
|
||||||
_assertNoEmpty2d(x)
|
_assertNoEmpty2d(x)
|
||||||
if p in (None, 2):
|
if p in (None, 2):
|
||||||
s = la.svd(x, compute_uv=False)
|
s = la.svd(x, compute_uv=False)
|
||||||
@ -63,6 +64,7 @@ def cond(x, p=None):
|
|||||||
|
|
||||||
@_wraps(np.linalg.tensorinv)
|
@_wraps(np.linalg.tensorinv)
|
||||||
def tensorinv(a, ind=2):
|
def tensorinv(a, ind=2):
|
||||||
|
_check_arraylike('jnp.linalg.tensorinv', a)
|
||||||
a = jnp.asarray(a)
|
a = jnp.asarray(a)
|
||||||
oldshape = a.shape
|
oldshape = a.shape
|
||||||
prod = 1
|
prod = 1
|
||||||
@ -79,6 +81,7 @@ def tensorinv(a, ind=2):
|
|||||||
|
|
||||||
@_wraps(np.linalg.tensorsolve)
|
@_wraps(np.linalg.tensorsolve)
|
||||||
def tensorsolve(a, b, axes=None):
|
def tensorsolve(a, b, axes=None):
|
||||||
|
_check_arraylike('jnp.linalg.tensorsolve', a, b)
|
||||||
a = jnp.asarray(a)
|
a = jnp.asarray(a)
|
||||||
b = jnp.asarray(b)
|
b = jnp.asarray(b)
|
||||||
an = a.ndim
|
an = a.ndim
|
||||||
@ -107,101 +110,102 @@ def tensorsolve(a, b, axes=None):
|
|||||||
|
|
||||||
@_wraps(np.linalg.multi_dot)
|
@_wraps(np.linalg.multi_dot)
|
||||||
def multi_dot(arrays, *, precision=None):
|
def multi_dot(arrays, *, precision=None):
|
||||||
n = len(arrays)
|
_check_arraylike('jnp.linalg.multi_dot', *arrays)
|
||||||
# optimization only makes sense for len(arrays) > 2
|
n = len(arrays)
|
||||||
if n < 2:
|
# optimization only makes sense for len(arrays) > 2
|
||||||
raise ValueError("Expecting at least two arrays.")
|
if n < 2:
|
||||||
elif n == 2:
|
raise ValueError("Expecting at least two arrays.")
|
||||||
return jnp.dot(arrays[0], arrays[1], precision=precision)
|
elif n == 2:
|
||||||
|
return jnp.dot(arrays[0], arrays[1], precision=precision)
|
||||||
|
|
||||||
arrays = [jnp.asarray(a) for a in arrays]
|
arrays = [jnp.asarray(a) for a in arrays]
|
||||||
|
|
||||||
# save original ndim to reshape the result array into the proper form later
|
# save original ndim to reshape the result array into the proper form later
|
||||||
ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
|
ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
|
||||||
# Explicitly convert vectors to 2D arrays to keep the logic of the internal
|
# Explicitly convert vectors to 2D arrays to keep the logic of the internal
|
||||||
# _multi_dot_* functions as simple as possible.
|
# _multi_dot_* functions as simple as possible.
|
||||||
if arrays[0].ndim == 1:
|
if arrays[0].ndim == 1:
|
||||||
arrays[0] = jnp.atleast_2d(arrays[0])
|
arrays[0] = jnp.atleast_2d(arrays[0])
|
||||||
if arrays[-1].ndim == 1:
|
if arrays[-1].ndim == 1:
|
||||||
arrays[-1] = jnp.atleast_2d(arrays[-1]).T
|
arrays[-1] = jnp.atleast_2d(arrays[-1]).T
|
||||||
_assert2d(*arrays)
|
_assert2d(*arrays)
|
||||||
|
|
||||||
# _multi_dot_three is much faster than _multi_dot_matrix_chain_order
|
# _multi_dot_three is much faster than _multi_dot_matrix_chain_order
|
||||||
if n == 3:
|
if n == 3:
|
||||||
result = _multi_dot_three(*arrays, precision)
|
result = _multi_dot_three(*arrays, precision)
|
||||||
else:
|
else:
|
||||||
order = _multi_dot_matrix_chain_order(arrays)
|
order = _multi_dot_matrix_chain_order(arrays)
|
||||||
result = _multi_dot(arrays, order, 0, n - 1, precision)
|
result = _multi_dot(arrays, order, 0, n - 1, precision)
|
||||||
|
|
||||||
# return proper shape
|
# return proper shape
|
||||||
if ndim_first == 1 and ndim_last == 1:
|
if ndim_first == 1 and ndim_last == 1:
|
||||||
return result[0, 0] # scalar
|
return result[0, 0] # scalar
|
||||||
elif ndim_first == 1 or ndim_last == 1:
|
elif ndim_first == 1 or ndim_last == 1:
|
||||||
return result.ravel() # 1-D
|
return result.ravel() # 1-D
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _multi_dot_three(A, B, C, precision):
|
def _multi_dot_three(A, B, C, precision):
|
||||||
"""
|
"""
|
||||||
Find the best order for three arrays and do the multiplication.
|
Find the best order for three arrays and do the multiplication.
|
||||||
For three arguments `_multi_dot_three` is approximately 15 times faster
|
For three arguments `_multi_dot_three` is approximately 15 times faster
|
||||||
than `_multi_dot_matrix_chain_order`
|
than `_multi_dot_matrix_chain_order`
|
||||||
"""
|
"""
|
||||||
a0, a1b0 = A.shape
|
a0, a1b0 = A.shape
|
||||||
b1c0, c1 = C.shape
|
b1c0, c1 = C.shape
|
||||||
# cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
|
# cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
|
||||||
cost1 = a0 * b1c0 * (a1b0 + c1)
|
cost1 = a0 * b1c0 * (a1b0 + c1)
|
||||||
# cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
|
# cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
|
||||||
cost2 = a1b0 * c1 * (a0 + b1c0)
|
cost2 = a1b0 * c1 * (a0 + b1c0)
|
||||||
|
|
||||||
if cost1 < cost2:
|
if cost1 < cost2:
|
||||||
return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
|
return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
|
||||||
else:
|
else:
|
||||||
return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
|
return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
|
||||||
|
|
||||||
|
|
||||||
def _multi_dot_matrix_chain_order(arrays, return_costs=False):
|
def _multi_dot_matrix_chain_order(arrays, return_costs=False):
|
||||||
"""
|
"""
|
||||||
Return a jnp.array that encodes the optimal order of mutiplications.
|
Return a jnp.array that encodes the optimal order of mutiplications.
|
||||||
The optimal order array is then used by `_multi_dot()` to do the
|
The optimal order array is then used by `_multi_dot()` to do the
|
||||||
multiplication.
|
multiplication.
|
||||||
Also return the cost matrix if `return_costs` is `True`
|
Also return the cost matrix if `return_costs` is `True`
|
||||||
The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
|
The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
|
||||||
Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
|
Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
|
||||||
cost[i, j] = min([
|
cost[i, j] = min([
|
||||||
cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
|
cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
|
||||||
for k in range(i, j)])
|
for k in range(i, j)])
|
||||||
"""
|
"""
|
||||||
n = len(arrays)
|
n = len(arrays)
|
||||||
# p stores the dimensions of the matrices
|
# p stores the dimensions of the matrices
|
||||||
# Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
|
# Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
|
||||||
p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
|
p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
|
||||||
# m is a matrix of costs of the subproblems
|
# m is a matrix of costs of the subproblems
|
||||||
# m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
|
# m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
|
||||||
m = np.zeros((n, n), dtype=np.double)
|
m = np.zeros((n, n), dtype=np.double)
|
||||||
# s is the actual ordering
|
# s is the actual ordering
|
||||||
# s[i, j] is the value of k at which we split the product A_i..A_j
|
# s[i, j] is the value of k at which we split the product A_i..A_j
|
||||||
s = np.empty((n, n), dtype=np.intp)
|
s = np.empty((n, n), dtype=np.intp)
|
||||||
|
|
||||||
for l in range(1, n):
|
for l in range(1, n):
|
||||||
for i in range(n - l):
|
for i in range(n - l):
|
||||||
j = i + l
|
j = i + l
|
||||||
m[i, j] = jnp.inf
|
m[i, j] = jnp.inf
|
||||||
for k in range(i, j):
|
for k in range(i, j):
|
||||||
q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
|
q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
|
||||||
if q < m[i, j]:
|
if q < m[i, j]:
|
||||||
m[i, j] = q
|
m[i, j] = q
|
||||||
s[i, j] = k # Note that Cormen uses 1-based index
|
s[i, j] = k # Note that Cormen uses 1-based index
|
||||||
|
|
||||||
return (s, m) if return_costs else s
|
return (s, m) if return_costs else s
|
||||||
|
|
||||||
|
|
||||||
def _multi_dot(arrays, order, i, j, precision):
|
def _multi_dot(arrays, order, i, j, precision):
|
||||||
"""Actually do the multiplication with the given order."""
|
"""Actually do the multiplication with the given order."""
|
||||||
if i == j:
|
if i == j:
|
||||||
return arrays[i]
|
return arrays[i]
|
||||||
else:
|
else:
|
||||||
return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
|
return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
|
||||||
_multi_dot(arrays, order, order[i, j] + 1, j, precision),
|
_multi_dot(arrays, order, order[i, j] + 1, j, precision),
|
||||||
precision=precision)
|
precision=precision)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user