Implement jnp.linalg.multi_dot using opt_einsum
This commit is contained in:
  parent 5f702674f7
  commit 09810be0cd
@@ -1,3 +1,5 @@
(ahead-of-time-lowering)=

# Ahead-of-time lowering and compilation

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
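For context on the anchor added in this hunk: the ahead-of-time workflow it labels proceeds roughly as in the following sketch (`f` and its input are illustrative placeholders, not part of the diff):

import jax
import jax.numpy as jnp

def f(x):
  return x @ x  # any jittable function

x = jnp.ones((4, 4))

# Trace and lower f ahead of time, without running it.
lowered = jax.jit(f).lower(x)

# Inspect the lowered computation, e.g. its estimated cost.
print(lowered.cost_analysis()['flops'])

# Compile explicitly, then invoke the compiled executable.
compiled = lowered.compile()
print(compiled(x))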
@@ -16,6 +16,7 @@ from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import itertools
import math
import warnings
@@ -28,6 +29,7 @@ from jax import jit, custom_jvp
from jax import lax

from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import PrecisionLike
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
@@ -1924,3 +1926,95 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
                     f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.")
  a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape))
  return solve(a_arr, b_arr.ravel()).reshape(out_shape)


def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
"""Efficiently compute matrix products between a sequence of arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.multi_dot`.
|
||||
|
||||
JAX internally uses the opt_einsum library to compute the most efficient
|
||||
operation order.
|
||||
|
||||
Args:
|
||||
arrays: sequence of arrays. All must be two-dimensional, except the first
|
||||
and last which may be one-dimensional.
|
||||
precision: either ``None`` (default), which means the default precision for
|
||||
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but
|
||||
evaluated in the optimal order.
|
||||
|
||||
  This function exists because the cost of computing sequences of matmul operations
  can differ vastly depending on the order in which the operations are evaluated.
  For a single matmul, the number of floating point operations (flops) required to
  compute a matrix product can be approximated this way:

  >>> def approx_flops(x, y):
  ...   # for 2D x and y, with x.shape[1] == y.shape[0]
  ...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

  Suppose we have three matrices that we'd like to multiply in sequence:

  >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
  >>> x = jax.random.normal(key1, shape=(200, 5))
  >>> y = jax.random.normal(key2, shape=(5, 100))
  >>> z = jax.random.normal(key3, shape=(100, 10))

  Because of the associativity of matrix products, there are two orders in which we
  might evaluate the product ``x @ y @ z``, and both produce equivalent outputs up
  to floating point precision:

  >>> result1 = (x @ y) @ z
  >>> result2 = x @ (y @ z)
  >>> jnp.allclose(result1, result2, atol=1E-4)
  Array(True, dtype=bool)
  But the computational costs of these two orderings differ greatly:

  >>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
  (x @ y) @ z flops: 600000
  >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
  x @ (y @ z) flops: 30000

  The second approach is about 20x more efficient in terms of estimated flops!
  ``multi_dot`` is a function that will automatically choose the fastest
  computational path for such problems:

  >>> result3 = jnp.linalg.multi_dot([x, y, z])
  >>> jnp.allclose(result1, result3, atol=1E-4)
  Array(True, dtype=bool)

  We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops
  of each approach, and confirm that ``multi_dot`` is choosing the more efficient
  option:

  >>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
  600000.0
  >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
  30000.0
  >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
  30000.0
  """
  check_arraylike('jnp.linalg.multi_dot', *arrays)
  arrs: list[Array] = list(map(jnp.asarray, arrays))
  if len(arrs) < 2:
    raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}")
  if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and
          all(a.ndim == 2 for a in arrs[1:-1])):
    raise ValueError("multi_dot: input arrays must all be two-dimensional, except for"
                     " the first and last array, which may be 1 or 2 dimensional."
                     f" Got array shapes {[a.shape for a in arrs]}")
  if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])):
    raise ValueError("multi_dot: the last dimension of each array must match the first"
                     f" dimension of the following array. Got array shapes {[a.shape for a in arrs]}")
  # Operand i is assigned axes (i, i+1); the shared label i+1 marks the
  # dimension contracted with operand i+1.
  einsum_axes: list[tuple[int, ...]] = [(i, i + 1) for i in range(len(arrs))]
  if arrs[0].ndim == 1:
    einsum_axes[0] = einsum_axes[0][1:]
  if arrs[-1].ndim == 1:
    einsum_axes[-1] = einsum_axes[-1][:1]
  return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[arg-type, call-overload]
                    optimize='optimal', precision=precision)
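To see concretely what the call constructed above unrolls to for three 2-D operands, here is a standalone sketch reusing the docstring's example shapes (the variable names are illustrative):

import jax
import jax.numpy as jnp

key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key1, (200, 5))
y = jax.random.normal(key2, (5, 100))
z = jax.random.normal(key3, (100, 10))

# Axes (0, 1), (1, 2), (2, 3): each shared integer marks a contracted dimension.
r1 = jnp.einsum(x, (0, 1), y, (1, 2), z, (2, 3), optimize='optimal')

# The same contraction in subscript form; opt_einsum's 'optimal' path
# evaluates x @ (y @ z) here, the cheaper of the two orders.
r2 = jnp.einsum('ij,jk,kl', x, y, z, optimize='optimal')

assert r1.shape == (200, 10)
assert jnp.allclose(r1, r2, atol=1e-4)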
jax/_src/third_party/numpy/linalg.py (vendored)
@@ -32,13 +32,6 @@ def _assertNdSquareness(*arrays):
                'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
    for a in arrays:
        if a.ndim != 2:
            raise ValueError(f'{a.ndim}-dimensional array given. '
                             'Array must be two-dimensional')


@implements(np.linalg.cond)
def cond(x, p=None):
    check_arraylike('jnp.linalg.cond', x)
@@ -60,106 +53,3 @@ def cond(x, p=None):
    nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
    r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
    return r


@implements(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
    check_arraylike('jnp.linalg.multi_dot', *arrays)
    n = len(arrays)
    # optimization only makes sense for len(arrays) > 2
    if n < 2:
        raise ValueError("Expecting at least two arrays.")
    elif n == 2:
        return jnp.dot(arrays[0], arrays[1], precision=precision)

    arrays = [jnp.asarray(a) for a in arrays]

    # save original ndim to reshape the result array into the proper form later
    ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
    # Explicitly convert vectors to 2D arrays to keep the logic of the internal
    # _multi_dot_* functions as simple as possible.
    if arrays[0].ndim == 1:
        arrays[0] = jnp.atleast_2d(arrays[0])
    if arrays[-1].ndim == 1:
        arrays[-1] = jnp.atleast_2d(arrays[-1]).T
    _assert2d(*arrays)

    # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
    if n == 3:
        result = _multi_dot_three(*arrays, precision)
    else:
        order = _multi_dot_matrix_chain_order(arrays)
        result = _multi_dot(arrays, order, 0, n - 1, precision)

    # return proper shape
    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0]  # scalar
    elif ndim_first == 1 or ndim_last == 1:
        return result.ravel()  # 1-D
    else:
        return result


def _multi_dot_three(A, B, C, precision):
    """
    Find the best order for three arrays and do the multiplication.

    For three arguments `_multi_dot_three` is approximately 15 times faster
    than `_multi_dot_matrix_chain_order`.
    """
    a0, a1b0 = A.shape
    b1c0, c1 = C.shape
    # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
    cost1 = a0 * b1c0 * (a1b0 + c1)
    # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
    cost2 = a1b0 * c1 * (a0 + b1c0)

    if cost1 < cost2:
        return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
    else:
        return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
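As a quick sanity check on these two cost expressions, the following standalone sketch evaluates them for the shapes used in the new docstring above (x: 200x5, y: 5x100, z: 100x10; reusing those shapes is an assumption, not part of the diff). The results come out to exactly half of the docstring's flops estimates, since these expressions count scalar multiplications rather than multiply-add pairs:

# Cost expressions from _multi_dot_three, for A = (200, 5), B = (5, 100),
# C = (100, 10) -- shapes borrowed from the docstring example above.
a0, a1b0 = 200, 5    # A.shape
b1c0, c1 = 100, 10   # C.shape

cost1 = a0 * b1c0 * (a1b0 + c1)  # cost((AB)C) = 200 * 100 * 15 = 300000
cost2 = a1b0 * c1 * (a0 + b1c0)  # cost(A(BC)) = 5 * 10 * 300 = 15000
assert cost2 < cost1  # A(BC) wins, matching the 600000 vs 30000 flop estimates.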
def _multi_dot_matrix_chain_order(arrays, return_costs=False):
    """
    Return a jnp.array that encodes the optimal order of multiplications.

    The optimal order array is then used by `_multi_dot()` to do the
    multiplication.

    Also return the cost matrix if `return_costs` is `True`.

    The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
    Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.

        cost[i, j] = min([
            cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
            for k in range(i, j)])
    """
    n = len(arrays)
    # p stores the dimensions of the matrices
    # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
    p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
    # m is a matrix of costs of the subproblems
    # m[i, j]: min number of scalar multiplications needed to compute A_{i..j}
    m = np.zeros((n, n), dtype=np.double)
    # s is the actual ordering
    # s[i, j] is the value of k at which we split the product A_i..A_j
    s = np.empty((n, n), dtype=np.intp)

    for l in range(1, n):
        for i in range(n - l):
            j = i + l
            m[i, j] = jnp.inf
            for k in range(i, j):
                q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
                if q < m[i, j]:
                    m[i, j] = q
                    s[i, j] = k  # Note that Cormen uses 1-based index

    return (s, m) if return_costs else s
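To make the recurrence concrete, here is a standalone sketch that runs the same dynamic program on the shapes from the comment above (A 10x100, B 100x5, C 5x50, so p = [10, 100, 5, 50]); it is plain NumPy, independent of the vendored code:

import numpy as np

# Matrix-chain DP on p = [10, 100, 5, 50], i.e. A (10x100), B (100x5), C (5x50).
p = [10, 100, 5, 50]
n = len(p) - 1  # number of matrices in the chain

m = np.zeros((n, n))             # m[i, j]: min scalar mults for A_i..A_j
s = np.zeros((n, n), dtype=int)  # s[i, j]: best split index k

for l in range(1, n):            # l is the chain length minus one
    for i in range(n - l):
        j = i + l
        m[i, j] = np.inf
        for k in range(i, j):
            q = m[i, k] + m[k + 1, j] + p[i] * p[k + 1] * p[j + 1]
            if q < m[i, j]:
                m[i, j], s[i, j] = q, k

print(m[0, n - 1])  # 7500.0 -> (A @ B) @ C needs 7500 scalar multiplications
print(s[0, n - 1])  # 1 -> split after B: evaluate (A B) first, then times C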
def _multi_dot(arrays, order, i, j, precision):
    """Actually do the multiplication with the given order."""
    if i == j:
        return arrays[i]
    else:
        return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
                       _multi_dot(arrays, order, order[i, j] + 1, j, precision),
                       precision=precision)
@@ -31,6 +31,7 @@ from jax._src.numpy.linalg import (
  matrix_power as matrix_power,
  matrix_rank as matrix_rank,
  matrix_transpose as matrix_transpose,
  multi_dot as multi_dot,
  norm as norm,
  outer as outer,
  pinv as pinv,
@@ -47,5 +48,4 @@ from jax._src.numpy.linalg import (
)
from jax._src.third_party.numpy.linalg import (
  cond as cond,
  multi_dot as multi_dot,
)