jnp.inner: add preferred_element_type argument

This commit is contained in:
Jake VanderPlas 2023-09-14 16:40:19 -07:00
parent c1ec78f35c
commit 3386e54fe0
2 changed files with 9 additions and 5 deletions

View File

@ -3483,13 +3483,17 @@ def _einsum(
@util._wraps(np.inner, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def inner(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DType | None = None,
) -> Array:
if ndim(a) == 0 or ndim(b) == 0:
return asarray(a) * asarray(b)
return tensordot(a, b, (-1, -1), precision=precision)
a = asarray(a, dtype=preferred_element_type)
b = asarray(b, dtype=preferred_element_type)
return a * b
return tensordot(a, b, (-1, -1), precision=precision,
preferred_element_type=preferred_element_type)
@util._wraps(np.outer, skip_params=['out'])

View File

@ -4890,7 +4890,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
ones_1d, ones_1d)
@jtu.sample_product(
funcname=['matmul', 'dot', 'vdot', 'tensordot']
funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot']
)
def testPreferredElementType(self, funcname):
func = getattr(jnp, funcname)