mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Move jit to the callsite.
PiperOrigin-RevId: 589328135
This commit is contained in:
parent
c3bc459f69
commit
709564ab78
@ -3495,10 +3495,11 @@ def einsum(
|
||||
|
||||
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
|
||||
|
||||
_einsum_computation = jax.named_call(
|
||||
_einsum, name=spec) if spec is not None else _einsum
|
||||
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
|
||||
preferred_element_type, _dot_general)
|
||||
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
|
||||
if spec is not None:
|
||||
einsum = jax.named_call(einsum, name=spec)
|
||||
return einsum(operands, contractions, precision, # type: ignore[operator]
|
||||
preferred_element_type, _dot_general)
|
||||
|
||||
|
||||
# Enable other modules to override einsum_contact_path.
|
||||
@ -3523,7 +3524,6 @@ def _removechars(s, chars):
|
||||
return s.translate(str.maketrans(dict.fromkeys(chars)))
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3, 4), inline=True)
|
||||
def _einsum(
|
||||
operands: Sequence,
|
||||
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user