mirror of https://github.com/ROCm/jax.git
refactor: move jnp.einsum impl into its own submodule
This commit is contained in:
parent 837418c652
commit 7ab7b214ac
@@ -42,7 +42,7 @@ from jax._src import dtypes
from jax._src import effects
from jax._src.lax import lax
from jax._src.interpreters import mlir
-from jax._src.numpy import lax_numpy
+from jax._src.numpy import einsum as jnp_einsum
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
@@ -1267,7 +1267,7 @@ def _einsum_contract_path(*operands, **kwargs):
    contract_operands.append(operands[idx[0]])
  return contract_operands, contractions

-lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
+jnp_einsum._poly_einsum_handlers[_DimExpr] = _einsum_contract_path

# To implement shape-constraint checking we use a shape assertion primitive.
# shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
576 jax/_src/numpy/einsum.py Normal file
@@ -0,0 +1,576 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import overload, Any, Callable, Sequence

import numpy as np
import opt_einsum

from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src.api import jit, named_call
from jax._src.lax import lax
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy import util
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding, PartitionSpec as P
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import partition_list, set_module, unzip2


export = set_module('jax.numpy')


class Unoptimized(opt_einsum.paths.PathOptimizer):
  """Unoptimized path for einsum."""
  def __call__(self, inputs, *args, **kwargs):
    return [(0, 1)] * (len(inputs) - 1)
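
# Note (illustrative comment, not part of the original file): each (0, 1) step
# above asks opt_einsum to contract the first two entries of its current
# operand list, appending the intermediate to the end of that list, so this
# path performs len(inputs) - 1 pairwise contractions with no path search.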

@overload
def einsum(
    subscript: str, /,
    *operands: ArrayLike,
    out: None = None,
    optimize: str | bool | list[tuple[int, ...]] = "auto",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
    out_sharding=None,
) -> Array: ...

@overload
def einsum(
    arr: ArrayLike,
    axes: Sequence[Any], /,
    *operands: ArrayLike | Sequence[Any],
    out: None = None,
    optimize: str | bool | list[tuple[int, ...]] = "auto",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
    out_sharding=None,
) -> Array: ...

@export
def einsum(
    subscripts, /,
    *operands,
    out: None = None,
    optimize: str | bool | list[tuple[int, ...]] = "auto",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
    out_sharding=None,
) -> Array:
  """Einstein summation

  JAX implementation of :func:`numpy.einsum`.

  ``einsum`` is a powerful and generic API for computing various reductions,
  inner products, outer products, axis reorderings, and combinations thereof
  across one or more input arrays. It has a somewhat complicated overloaded API;
  the arguments below reflect the most common calling convention. The Examples
  section below demonstrates some of the alternative calling conventions.

  Args:
    subscripts: string containing axes names separated by commas.
    *operands: sequence of one or more arrays corresponding to the subscripts.
    optimize: specify how to optimize the order of computation. In JAX this defaults
      to ``"auto"`` which produces optimized expressions via the opt_einsum_
      package. Other options are ``True`` (same as ``"optimal"``), ``False``
      (unoptimized), or any string supported by ``opt_einsum``, which
      includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
      be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
    precision: either ``None`` (default), which means the default precision for
      the backend, or a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).
    preferred_element_type: either ``None`` (default), which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.
    out: unsupported by JAX
    _dot_general: optionally override the ``dot_general`` callable used by ``einsum``.
      This parameter is experimental, and may be removed without warning at any time.

  Returns:
    array containing the result of the Einstein summation.

  See also:
    :func:`jax.numpy.einsum_path`

  Examples:
    The mechanics of ``einsum`` are perhaps best demonstrated by example. Here we
    show how to use ``einsum`` to compute a number of quantities from one or more
    arrays. For more discussion and examples of ``einsum``, see the documentation
    of :func:`numpy.einsum`.

    >>> M = jnp.arange(16).reshape(4, 4)
    >>> x = jnp.arange(4)
    >>> y = jnp.array([5, 4, 3, 2])

    **Vector product**

    >>> jnp.einsum('i,i', x, y)
    Array(16, dtype=int32)
    >>> jnp.vecdot(x, y)
    Array(16, dtype=int32)

    Here are some alternative ``einsum`` calling conventions to compute the same
    result:

    >>> jnp.einsum('i,i->', x, y)  # explicit form
    Array(16, dtype=int32)
    >>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
    Array(16, dtype=int32)
    >>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
    Array(16, dtype=int32)

    **Matrix product**

    >>> jnp.einsum('ij,j->i', M, x)  # explicit form
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.matmul(M, x)
    Array([14, 38, 62, 86], dtype=int32)

    Here are some alternative ``einsum`` calling conventions to compute the same
    result:

    >>> jnp.einsum('ij,j', M, x)  # implicit form
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
    Array([14, 38, 62, 86], dtype=int32)

    **Outer product**

    >>> jnp.einsum("i,j->ij", x, y)
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.outer(x, y)
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)

    Some other ways of computing outer products:

    >>> jnp.einsum("i,j", x, y)  # implicit form
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)

    **1D array sum**

    >>> jnp.einsum("i->", x)  # requires explicit form
    Array(6, dtype=int32)
    >>> jnp.einsum(x, (0,), ())  # explicit form via indices
    Array(6, dtype=int32)
    >>> jnp.sum(x)
    Array(6, dtype=int32)

    **Sum along an axis**

    >>> jnp.einsum("...j->...", M)  # requires explicit form
    Array([ 6, 22, 38, 54], dtype=int32)
    >>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
    Array([ 6, 22, 38, 54], dtype=int32)
    >>> M.sum(-1)
    Array([ 6, 22, 38, 54], dtype=int32)

    **Matrix transpose**

    >>> y = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.einsum("ij->ji", y)  # explicit form
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum("ji", y)  # implicit form
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum(y, (1, 0))  # implicit form via indices
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.transpose(y)
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)

    **Matrix diagonal**

    >>> jnp.einsum("ii->i", M)
    Array([ 0,  5, 10, 15], dtype=int32)
    >>> jnp.diagonal(M)
    Array([ 0,  5, 10, 15], dtype=int32)

    **Matrix trace**

    >>> jnp.einsum("ii", M)
    Array(30, dtype=int32)
    >>> jnp.trace(M)
    Array(30, dtype=int32)

    **Tensor products**

    >>> x = jnp.arange(30).reshape(2, 3, 5)
    >>> y = jnp.arange(60).reshape(3, 4, 5)
    >>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum('ijk,jlk', x, y)  # implicit form
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)

    **Chained dot products**

    >>> w = jnp.arange(5, 9).reshape(2, 2)
    >>> x = jnp.arange(6).reshape(2, 3)
    >>> y = jnp.arange(-2, 4).reshape(3, 2)
    >>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
    >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> w @ x @ y @ z  # direct chain of matmuls
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> jnp.linalg.multi_dot([w, x, y, z])
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)

  .. _opt_einsum: https://github.com/dgasmith/opt_einsum
  """
  operands = (subscripts, *operands)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
  spec = operands[0] if isinstance(operands[0], str) else None
  path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize

  # Allow handling of shape polymorphism
  non_constant_dim_types = {
      type(d) for op in operands if not isinstance(op, str)
      for d in np.shape(op) if not core.is_constant_dim(d)
  }
  if not non_constant_dim_types:
    contract_path = opt_einsum.contract_path
  else:
    ty = next(iter(non_constant_dim_types))
    contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
  # using einsum_call=True here is an internal api for opt_einsum... sorry
  operands, contractions = contract_path(
      *operands, einsum_call=True, use_blas=True, optimize=path_type)

  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

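  # Illustrative note (not part of the original file): argument positions 1-5 of
  # _einsum (contractions, precision, preferred_element_type, _dot_general and
  # out_sharding) hold hashable Python data, so they are marked static below and
  # jit traces only over the operand arrays; the frozenset built above keeps
  # each contraction step hashable for that purpose.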
  jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
  if spec is not None:
    jit_einsum = named_call(jit_einsum, name=spec)
  operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
  return jit_einsum(operand_arrays, contractions, precision,
                    preferred_element_type, _dot_general, out_sharding)


# Enable other modules to override einsum_contract_path.
# Indexed by the type of the non-constant dimension
_poly_einsum_handlers = {}  # type: ignore

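# Illustrative note (not part of the original file): the default handler below
# swaps every non-constant (symbolic) dimension for a fixed placeholder size of
# 8 so that opt_einsum can still choose a contraction order, then maps the
# resulting path back onto the original, possibly shape-polymorphic operands.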
def _default_poly_einsum_handler(*operands, **kwargs):
  dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
  dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
             if hasattr(x, 'dtype') else x for x in operands]
  mapping = {id(d): i for i, d in enumerate(dummies)}
  out_dummies, contractions = opt_einsum.contract_path(*dummies, **kwargs)
  contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
  return contract_operands, contractions

@overload
def einsum_path(
    subscripts: str, /,
    *operands: ArrayLike,
    optimize: bool | str | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

@overload
def einsum_path(
    arr: ArrayLike,
    axes: Sequence[Any], /,
    *operands: ArrayLike | Sequence[Any],
    optimize: bool | str | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

@export
def einsum_path(
    subscripts, /,
    *operands,
    optimize: bool | str | list[tuple[int, ...]] = 'auto'
) -> tuple[list[tuple[int, ...]], Any]:
  """Evaluates the optimal contraction path without evaluating the einsum.

  JAX implementation of :func:`numpy.einsum_path`. This function calls into
  the opt_einsum_ package, and makes use of its optimization routines.

  Args:
    subscripts: string containing axes names separated by commas.
    *operands: sequence of one or more arrays corresponding to the subscripts.
    optimize: specify how to optimize the order of computation. In JAX this defaults
      to ``"auto"``. Other options are ``True`` (same as ``"optimal"``), ``False``
      (unoptimized), or any string supported by ``opt_einsum``, which
      includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others.

  Returns:
    A tuple containing the path that may be passed to :func:`~jax.numpy.einsum`, and a
    printable object representing this optimal path.

  Examples:
    >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
    >>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
    >>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
    >>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
    >>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
    >>> print(path)
    [(1, 2), (0, 1)]
    >>> print(path_info)
      Complete contraction:  ij,jk,kl->il
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  9.000e+3
      Optimized FLOP count:  3.060e+3
       Theoretical speedup:  2.941e+0
      Largest intermediate:  1.500e+1 elements
    --------------------------------------------------------------------------------
    scaling        BLAS                current                             remaining
    --------------------------------------------------------------------------------
       3           GEMM              kl,jk->lj                           ij,lj->il
       3           GEMM              lj,ij->il                              il->il

    Use the computed path in :func:`~jax.numpy.einsum`:

    >>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
    Array([[-754,  324, -142,   82,   50],
           [ 408,  -50,   87,  -29,    7]], dtype=int32)

  .. _opt_einsum: https://github.com/dgasmith/opt_einsum
  """
  if optimize is True:
    optimize = 'optimal'
  elif optimize is False:
    optimize = Unoptimized()
  return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)

def _removechars(s, chars):
  return s.translate(str.maketrans(dict.fromkeys(chars)))
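# For example (illustrative, not part of the original file):
# _removechars('ijk', 'ik') evaluates to 'j'.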


def _einsum(
    operands: list[Array],
    contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
    precision,
    preferred_element_type,
    _dot_general=lax.dot_general,
    out_sharding=None,
):
  if out_sharding is not None and not config.sharding_in_types.value:
    raise NotImplementedError("out_sharding only works when sharding_in_types "
                              "config is True.")
  out_sharding = canonicalize_sharding(out_sharding)
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError(
        "`out_sharding` argument of `einsum` only supports NamedSharding"
        " instances. Please file a bug if this is not enough for your use case.")
  dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
  if preferred_element_type is None:
    preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
  else:
    output_weak_type = False

  def sum(x, axes):
    if dtypes.result_type(x, preferred_element_type) != x.dtype:
      x = x.astype(preferred_element_type)
    return lax.reduce(x, np.array(0, x.dtype),
                      lax.add if x.dtype != bool else lax.bitwise_or, axes)

  def sum_uniques(operand, names, uniques):
    if uniques:
      axes = [names.index(name) for name in uniques]
      operand = sum(operand, axes)
      names = _removechars(names, uniques)
    return operand, names

  def sum_repeats(operand, names, counts, keep_names):
    for name, count in counts.items():
      if count > 1:
        axes = [i for i, n in enumerate(names) if n == name]
        eye = lax._delta(np.dtype('bool'), operand.shape, axes)
        operand = lax.select(eye, operand, lax.full_like(operand, 0))
        if name not in keep_names:
          operand = sum(operand, axes)
          names = names.replace(name, '')
        else:
          operand = sum(operand, axes[:-1])
          names = names.replace(name, '', count - 1)
    return operand, names

  def filter_singleton_dims(operand, names, other_shape, other_names):
    eq = core.definitely_equal
    keep = [not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
            for i, j in enumerate(map(other_names.find, names))]
    sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
    return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)

  for operand_indices, contracted_names_set, einstr in contractions:
    contracted_names = sorted(contracted_names_set)
    input_str, result_names = einstr.split('->')
    input_names = input_str.split(',')

    # switch on the number of operands to be processed in this loop iteration.
    # every case here sets 'operand' and 'names'.
    if len(operand_indices) == 1:
      operand = operands.pop(operand_indices[0])
      names, = input_names
      counts = collections.Counter(names)

      # sum out unique contracted indices with a single reduce-sum
      uniques = [name for name in contracted_names if counts[name] == 1]
      operand, names = sum_uniques(operand, names, uniques)

      # for every repeated index, do a contraction against an identity matrix
      operand, names = sum_repeats(operand, names, counts, result_names)

    elif len(operand_indices) == 2:
      lhs, rhs = map(operands.pop, operand_indices)
      lhs_names, rhs_names = input_names

      # handle cases where one side of a contracting or batch dimension is 1
      # but its counterpart is not.
      lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, np.shape(rhs),
                                             rhs_names)
      rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, np.shape(lhs),
                                             lhs_names)

      lhs_counts = collections.Counter(lhs_names)
      rhs_counts = collections.Counter(rhs_names)

      # sum out unique contracted indices in lhs and rhs
      lhs_uniques = [name for name in contracted_names
                     if lhs_counts[name] == 1 and rhs_counts[name] == 0]
      lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

      rhs_uniques = [name for name in contracted_names
                     if rhs_counts[name] == 1 and lhs_counts[name] == 0]
      rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

      # for every repeated index, contract against an identity matrix
      lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
                                   result_names + rhs_names)
      rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
                                   result_names + lhs_names)

      lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
      contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
      lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
      batch_names = [x for x in result_names if x in lhs_and_rhs_names]

      lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
                                    for n in batch_names)

      # NOTE(mattjj): this can fail non-deterministically in python3, maybe
      # due to opt_einsum
      assert config.dynamic_shapes.value or all(
        name in lhs_names and name in rhs_names and
        lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
        for name in contracted_names), (
          "Incompatible reduction dimensions: "
          f"lhs.shape={lhs.shape} lhs_names={lhs_names} "
          f"rhs.shape={rhs.shape} rhs_names={rhs_names}")

      # contract using dot_general
      batch_names_str = ''.join(batch_names)
      lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                  for n in contracted_names)
      deleted_names = batch_names_str + ''.join(contracted_names)
      remaining_lhs_names = _removechars(lhs_names, deleted_names)
      remaining_rhs_names = _removechars(rhs_names, deleted_names)
      # Try both orders of lhs and rhs, in the hope that one of them means we
      # don't need an explicit transpose. opt_einsum likes to contract from
      # right to left, so we expect (rhs,lhs) to have the best chance of not
      # needing a transpose.
      names = batch_names_str + remaining_rhs_names + remaining_lhs_names
      if names == result_names:
        dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
        k_out_sharding = ({} if out_sharding is None else
                          {'out_sharding': out_sharding})
        operand = _dot_general(rhs, lhs, dimension_numbers, precision,
                               preferred_element_type=preferred_element_type,
                               **k_out_sharding)
      else:
        names = batch_names_str + remaining_lhs_names + remaining_rhs_names
        if (config.sharding_in_types.value and out_sharding is not None and
            names != result_names):
          spec = out_sharding.spec
          inverse_spec = tuple(spec[result_names.index(name)] for name in names)
          dot_general_out_sharding = NamedSharding(out_sharding.mesh,
                                                   P(*inverse_spec))
        else:
          dot_general_out_sharding = out_sharding  # type: ignore
        dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
        dot_general_out_sharding = ({} if dot_general_out_sharding is None else  # type: ignore
                                    {'out_sharding': dot_general_out_sharding})
        operand = _dot_general(lhs, rhs, dimension_numbers, precision,
                               preferred_element_type=preferred_element_type,
                               **dot_general_out_sharding)
    else:
      raise NotImplementedError  # if this is actually reachable, open an issue!

    # the resulting 'operand' with axis labels 'names' should be a permutation
    # of the desired result
    assert len(names) == len(result_names) == len(set(names))
    assert set(names) == set(result_names)
    if names != result_names:
      perm = tuple(names.index(name) for name in result_names)
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration

  return lax._convert_element_type(operands[0], preferred_element_type,
                                   output_weak_type)
@@ -26,7 +26,6 @@ rules for the underlying :code:`lax` primitives.
from __future__ import annotations

import builtins
import collections
from collections.abc import Callable, Sequence
from functools import partial
import importlib
@@ -63,6 +62,7 @@ from jax._src.numpy import reductions
from jax._src.numpy import tensor_contractions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
+from jax._src.numpy.einsum import einsum
from jax._src.numpy.sorting import argsort, sort
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import (
@@ -70,14 +70,12 @@ from jax._src.typing import (
)
from jax._src.util import (
    NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
-    ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
+    ceil_of_ratio, safe_zip, set_module, unzip2,
    tuple_replace)
from jax.sharding import Sharding
-from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
-                                     PartitionSpec as P, canonicalize_sharding)
+from jax._src.sharding_impls import SingleDeviceSharding
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np
import opt_einsum

export = set_module('jax.numpy')

@@ -8546,548 +8544,6 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
      raise ValueError("function is not returning an array of the correct shape")
  return a_arr

[Removed block, not reproduced here: the previous in-place definitions of Unoptimized, einsum (overloads, docstring, and implementation), _poly_einsum_handlers, _default_poly_einsum_handler, einsum_path, _removechars, and _einsum, matching the new jax/_src/numpy/einsum.py shown above apart from module-local spellings such as jax.Array, jax.named_call, lax_internal._delta, zeros_like, and shape.]

@export
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
@@ -78,8 +78,6 @@ from jax._src.numpy.lax_numpy import (
    dtype as dtype,
    e as e,
    ediff1d as ediff1d,
-    einsum as einsum,
-    einsum_path as einsum_path,
    euler_gamma as euler_gamma,
    expand_dims as expand_dims,
    extract as extract,
@@ -208,6 +206,11 @@ from jax._src.numpy.array_creation import (
    zeros_like as zeros_like,
)

+from jax._src.numpy.einsum import (
+    einsum as einsum,
+    einsum_path as einsum_path,
+)
+
from jax._src.numpy.scalar_types import (
    bfloat16 as bfloat16,
    bool_ as bool,  # Array API alias for bool_  # noqa: F401