Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Merge pull request #21148 from jakevdp:einsum-path
PiperOrigin-RevId: 632470123
This commit is contained in:
commit c2d78abfa3
@@ -4260,12 +4260,17 @@ def tensordot(a: ArrayLike, b: ArrayLike,
   return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
 
 
+class Unoptimized(opt_einsum.paths.PathOptimizer):
+  """Unoptimized path for einsum."""
+  def __call__(self, inputs, *args, **kwargs):
+    return [(0, 1)] * (len(inputs) - 1)
+
+
 @overload
 def einsum(
     subscript: str, /,
     *operands: ArrayLike,
     out: None = None,
-    optimize: str | bool = "optimal",
+    optimize: str | bool | list[tuple[int, ...]] = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
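For orientation, a minimal sketch (not part of the commit) of what the path returned by `Unoptimized` means to opt_einsum: at every step, contract the first two remaining operands, i.e. strictly left to right with no reordering. Shapes and subscripts here are illustrative only.

    import numpy as np
    import opt_einsum

    inputs = ['ij', 'jk', 'kl']           # three operands -> two pairwise steps
    path = [(0, 1)] * (len(inputs) - 1)   # [(0, 1), (0, 1)]

    a, b, c = np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))
    # opt_einsum accepts an explicit contraction path via ``optimize=``:
    result = opt_einsum.contract('ij,jk,kl->il', a, b, c, optimize=path)
    print(result.shape)                   # (2, 5)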
@@ -4277,7 +4282,7 @@ def einsum(
     axes: Sequence[Any], /,
     *operands: ArrayLike | Sequence[Any],
     out: None = None,
-    optimize: str | bool = "optimal",
+    optimize: str | bool | list[tuple[int, ...]] = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
@@ -4287,7 +4292,7 @@ def einsum(
     subscripts, /,
     *operands,
     out: None = None,
-    optimize: str | bool = "optimal",
+    optimize: str | bool | list[tuple[int, ...]] = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
@@ -4305,9 +4310,12 @@ def einsum(
   Args:
     subscripts: string containing axes names separated by commas.
     *operands: sequence of one or more arrays corresponding to the subscripts.
-    optimize: determine whether to optimize the order of computation. In JAX
-      this defaults to ``"optimize"`` which produces optimized expressions via
-      the opt_einsum_ package.
+    optimize: specify how to optimize the order of computation. In JAX this defaults
+      to ``"optimal"`` which produces optimized expressions via the opt_einsum_
+      package. Other options are ``True`` (same as ``"optimal"``), ``False``
+      (unoptimized), or any string supported by ``opt_einsum``, which
+      includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
+      be a pre-computed path (see :func:`~jax.numpy.einsum_path`)
     precision: either ``None`` (default), which means the default precision for
       the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
       ``Precision.HIGH`` or ``Precision.HIGHEST``).
@@ -4321,6 +4329,9 @@ def einsum(
   Returns:
     array containing the result of the einstein summation.
 
+  See also:
+    :func:`jax.numpy.einsum_path`
+
   Examples:
     The mechanics of ``einsum`` are perhaps best demonstrated by example. Here we
     show how to use ``einsum`` to compute a number of quantities from one or more
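As a hedged usage sketch of the options the updated docstring lists (shapes and values below are made up, not from the PR):

    import jax.numpy as jnp

    x, y, z = jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4, 5))

    # Mode strings are forwarded to opt_einsum:
    r1 = jnp.einsum('ij,jk,kl->il', x, y, z, optimize='greedy')

    # A pre-computed path from einsum_path can be reused across calls:
    path, _ = jnp.einsum_path('ij,jk,kl->il', x, y, z, optimize='optimal')
    r2 = jnp.einsum('ij,jk,kl->il', x, y, z, optimize=path)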
@@ -4498,7 +4509,7 @@ def einsum(
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
   spec = operands[0] if isinstance(operands[0], str) else None
-  optimize = 'optimal' if optimize is True else optimize
+  path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize
 
   # Allow handling of shape polymorphism
   non_constant_dim_types = {
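A self-contained sketch of the dispatch on the new ``path_type`` line; ``resolve_path_type`` is a hypothetical helper name introduced only for illustration, and ``Unoptimized`` is copied from the hunk above:

    import opt_einsum

    class Unoptimized(opt_einsum.paths.PathOptimizer):
      """Unoptimized path for einsum."""
      def __call__(self, inputs, *args, **kwargs):
        return [(0, 1)] * (len(inputs) - 1)

    def resolve_path_type(optimize):
      # True  -> opt_einsum's exhaustive 'optimal' search
      # False -> left-to-right pairwise contraction, no reordering
      # else  -> passed through: a mode string or an explicit path
      return 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize

    assert resolve_path_type(True) == 'optimal'
    assert isinstance(resolve_path_type(False), Unoptimized)
    assert resolve_path_type([(0, 1)]) == [(0, 1)]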
@@ -4512,7 +4523,7 @@ def einsum(
   contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
   # using einsum_call=True here is an internal api for opt_einsum... sorry
   operands, contractions = contract_path(
-      *operands, einsum_call=True, use_blas=True, optimize=optimize)
+      *operands, einsum_call=True, use_blas=True, optimize=path_type)
 
   contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
 
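The call above uses ``einsum_call=True``, which (as the inline comment concedes) is internal to opt_einsum. For orientation, the public counterpart is ``opt_einsum.contract_path``, sketched here with made-up shapes:

    import numpy as np
    import opt_einsum

    a, b, c = np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))
    path, info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c,
                                          optimize='greedy')
    print(path)   # e.g. [(0, 1), (0, 1)]
    print(info)   # human-readable scaling/cost report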
@@ -286,7 +286,7 @@ def einsum(
     subscript: str, /,
     *operands: ArrayLike,
     out: None = ...,
-    optimize: Union[str, builtins.bool] = "optimal",
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
     precision: PrecisionLike = ...,
     preferred_element_type: Optional[DTypeLike] = ...,
     _use_xeinsum: builtins.bool = False,
@@ -299,7 +299,7 @@ def einsum(
     axes: Sequence[Any], /,
     *operands: Union[ArrayLike, Sequence[Any]],
     out: None = ...,
-    optimize: Union[str, builtins.bool] = "optimal",
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
     precision: PrecisionLike = ...,
     preferred_element_type: Optional[DTypeLike] = ...,
    _use_xeinsum: builtins.bool = False,
@@ -310,7 +310,7 @@ def einsum(
     subscripts, /,
     *operands,
     out: None = ...,
-    optimize: Union[str, builtins.bool] = ...,
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
     precision: PrecisionLike = ...,
     preferred_element_type: Optional[DTypeLike] = ...,
     _use_xeinsum: builtins.bool = ...,
@@ -392,6 +392,25 @@ class EinsumTest(jtu.JaxTestCase):
     f_np = jtu.promote_like_jnp(partial(np.einsum, 'a,a->a'))
     self._CheckAgainstNumpy(f_np, f_jax, args_maker, check_dtypes=True)
 
+  @jtu.sample_product(
+      [
+          {'signature': 'i->', 'shapes': [(3,)]},
+          {'signature': 'ii->i', 'shapes': [(4, 4)]},
+          {'signature': 'ij,jk', 'shapes': [(3, 4), (4, 3)]},
+          {'signature': 'ij,jkl,klm', 'shapes': [(2, 2), (2, 3, 4), (3, 4, 2)]},
+      ],
+      optimize=[True, False, 'optimal', 'auto', 'greedy', 'eager'],
+      dtype=[np.dtype('float32')],
+  )
+  @jtu.skip_on_devices('tpu')
+  def test_einsum_optimization_modes(self, signature, shapes, optimize, dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
+    jnp_fun = partial(jnp.einsum, signature, optimize=optimize)
+    np_fun = partial(np.einsum, signature)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4)
+    self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
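In the spirit of the new test, a hedged end-to-end check (a sketch, not from the PR): every optimization mode should produce the same values, differing only in contraction order.

    import jax.numpy as jnp
    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(2, 2))
    b = rng.normal(size=(2, 3, 4))
    c = rng.normal(size=(3, 4, 2))

    expected = np.einsum('ij,jkl,klm', a, b, c)
    for mode in (True, False, 'optimal', 'greedy'):
      actual = jnp.einsum('ij,jkl,klm', a, b, c, optimize=mode)
      np.testing.assert_allclose(actual, expected, rtol=1e-4)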