From 2f7204fff6e223dc6b676c36084323b8f51a0895 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 6 Jan 2025 08:21:59 -0800 Subject: [PATCH] jnp.einsum: default to optimize='auto' --- CHANGELOG.md | 3 +++ jax/_src/numpy/lax_numpy.py | 10 +++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ffdafbb1..356e16bd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes: * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025. + * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than + `optimize='optimal'`. This avoids exponentially-scaling trace-time in + the case of many arguments ({jax-issue}`#25214`). * New Features * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f23261252..2681cbc81 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9503,7 +9503,7 @@ def einsum( subscript: str, /, *operands: ArrayLike, out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9516,7 +9516,7 @@ def einsum( axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9528,7 +9528,7 @@ def einsum( subscripts, /, *operands, out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9548,10 +9548,10 @@ def einsum( subscripts: string containing axes names separated by commas. *operands: sequence of one or more arrays corresponding to the subscripts. optimize: specify how to optimize the order of computation. In JAX this defaults - to ``"optimal"`` which produces optimized expressions via the opt_einsum_ + to ``"auto"`` which produces optimized expressions via the opt_einsum_ package. Other options are ``True`` (same as ``"optimal"``), ``False`` (unoptimized), or any string supported by ``opt_einsum``, which - includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also + includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also be a pre-computed path (see :func:`~jax.numpy.einsum_path`). precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,