mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #26243 from jakevdp:einsum-asarray
PiperOrigin-RevId: 722455518
This commit is contained in:
commit
57fa37214c
@ -9744,10 +9744,11 @@ def einsum(
|
||||
|
||||
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
|
||||
|
||||
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
|
||||
jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
|
||||
if spec is not None:
|
||||
einsum = jax.named_call(einsum, name=spec)
|
||||
return einsum(operands, contractions, precision,
|
||||
jit_einsum = jax.named_call(jit_einsum, name=spec)
|
||||
operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
|
||||
return jit_einsum(operand_arrays, contractions, precision,
|
||||
preferred_element_type, _dot_general, out_sharding)
|
||||
|
||||
|
||||
@ -9843,7 +9844,7 @@ def _removechars(s, chars):
|
||||
|
||||
|
||||
def _einsum(
|
||||
operands: Sequence,
|
||||
operands: list[jax.Array],
|
||||
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
|
||||
precision,
|
||||
preferred_element_type,
|
||||
@ -9859,7 +9860,6 @@ def _einsum(
|
||||
"`out_sharding` argument of `einsum` only supports NamedSharding"
|
||||
" instances. Please file a bug if this is not enough for your use case.")
|
||||
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
|
||||
operands = list(map(asarray, operands))
|
||||
if preferred_element_type is None:
|
||||
preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user