mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jnp.einsum: default to optimize='auto'
This commit is contained in:
parent
c39e38fe5a
commit
2f7204fff6
@ -19,6 +19,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* Changes:
|
||||
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
|
||||
supported version until June 2025.
|
||||
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
|
||||
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
|
||||
the case of many arguments ({jax-issue}`#25214`).
|
||||
|
||||
* New Features
|
||||
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
|
||||
|
@ -9503,7 +9503,7 @@ def einsum(
|
||||
subscript: str, /,
|
||||
*operands: ArrayLike,
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9516,7 +9516,7 @@ def einsum(
|
||||
axes: Sequence[Any], /,
|
||||
*operands: ArrayLike | Sequence[Any],
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9528,7 +9528,7 @@ def einsum(
|
||||
subscripts, /,
|
||||
*operands,
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9548,10 +9548,10 @@ def einsum(
|
||||
subscripts: string containing axes names separated by commas.
|
||||
*operands: sequence of one or more arrays corresponding to the subscripts.
|
||||
optimize: specify how to optimize the order of computation. In JAX this defaults
|
||||
to ``"optimal"`` which produces optimized expressions via the opt_einsum_
|
||||
to ``"auto"`` which produces optimized expressions via the opt_einsum_
|
||||
package. Other options are ``True`` (same as ``"optimal"``), ``False``
|
||||
(unoptimized), or any string supported by ``opt_einsum``, which
|
||||
includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
|
||||
includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
|
||||
be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
|
||||
precision: either ``None`` (default), which means the default precision for
|
||||
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
||||
|
Loading…
x
Reference in New Issue
Block a user