Merge pull request #26243 from jakevdp:einsum-asarray

PiperOrigin-RevId: 722455518
This commit is contained in:
jax authors 2025-02-02 17:42:47 -08:00
commit 57fa37214c

View File

@ -9744,10 +9744,11 @@ def einsum(
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision,
jit_einsum = jax.named_call(jit_einsum, name=spec)
operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
return jit_einsum(operand_arrays, contractions, precision,
preferred_element_type, _dot_general, out_sharding)
@ -9843,7 +9844,7 @@ def _removechars(s, chars):
def _einsum(
operands: Sequence,
operands: list[jax.Array],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
@ -9859,7 +9860,6 @@ def _einsum(
"`out_sharding` argument of `einsum` only supports NamedSharding"
" instances. Please file a bug if this is not enough for your use case.")
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
operands = list(map(asarray, operands))
if preferred_element_type is None:
preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
else: