From 4e30a08e8458c84c1a581da33e3be2e44d71f79c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 31 Jan 2025 11:59:45 -0800 Subject: [PATCH] Avoid call to asarray in jnp.einsum --- jax/_src/numpy/lax_numpy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4203547c9..220fb97f1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9744,11 +9744,12 @@ def einsum( contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) + jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: - einsum = jax.named_call(einsum, name=spec) - return einsum(operands, contractions, precision, - preferred_element_type, _dot_general, out_sharding) + jit_einsum = jax.named_call(jit_einsum, name=spec) + operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands)) + return jit_einsum(operand_arrays, contractions, precision, + preferred_element_type, _dot_general, out_sharding) # Enable other modules to override einsum_contact_path. @@ -9843,7 +9844,7 @@ def _removechars(s, chars): def _einsum( - operands: Sequence, + operands: list[jax.Array], contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]], precision, preferred_element_type, @@ -9859,7 +9860,6 @@ def _einsum( "`out_sharding` argument of `einsum` only supports NamedSharding" " instances. Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") - operands = list(map(asarray, operands)) if preferred_element_type is None: preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True) else: