mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jnp.inner: add preferred_element_type argument
This commit is contained in:
parent
c1ec78f35c
commit
3386e54fe0
@ -3483,13 +3483,17 @@ def _einsum(
|
||||
|
||||
|
||||
@util._wraps(np.inner, lax_description=_PRECISION_DOC)
|
||||
@partial(jit, static_argnames=('precision',), inline=True)
|
||||
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
|
||||
def inner(
|
||||
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None
|
||||
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
|
||||
preferred_element_type: DType | None = None,
|
||||
) -> Array:
|
||||
if ndim(a) == 0 or ndim(b) == 0:
|
||||
return asarray(a) * asarray(b)
|
||||
return tensordot(a, b, (-1, -1), precision=precision)
|
||||
a = asarray(a, dtype=preferred_element_type)
|
||||
b = asarray(b, dtype=preferred_element_type)
|
||||
return a * b
|
||||
return tensordot(a, b, (-1, -1), precision=precision,
|
||||
preferred_element_type=preferred_element_type)
|
||||
|
||||
|
||||
@util._wraps(np.outer, skip_params=['out'])
|
||||
|
@ -4890,7 +4890,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
ones_1d, ones_1d)
|
||||
|
||||
@jtu.sample_product(
|
||||
funcname=['matmul', 'dot', 'vdot', 'tensordot']
|
||||
funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot']
|
||||
)
|
||||
def testPreferredElementType(self, funcname):
|
||||
func = getattr(jnp, funcname)
|
||||
|
Loading…
x
Reference in New Issue
Block a user