Move jit to the callsite.

PiperOrigin-RevId: 589328135
This commit is contained in:
jax authors 2023-12-08 22:19:19 -08:00
parent c3bc459f69
commit 709564ab78

View File

@ -3495,10 +3495,11 @@ def einsum(
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
_einsum_computation = jax.named_call(
_einsum, name=spec) if spec is not None else _einsum
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)
# Enable other modules to override einsum_contact_path.
@ -3523,7 +3524,6 @@ def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
@partial(jit, static_argnums=(1, 2, 3, 4), inline=True)
def _einsum(
operands: Sequence,
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],