diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 7a333f424..51839c734 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3072,7 +3072,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
 
 
 @util._wraps(np.dot, lax_description=_PRECISION_DOC)
-@partial(jit, static_argnames=('precision',), inline=True)
+@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def dot(a: ArrayLike, b: ArrayLike, *,
         precision: PrecisionLike = None,
         preferred_element_type: DTypeLike | None = None) -> Array:
@@ -3104,7 +3104,7 @@ def dot(a: ArrayLike, b: ArrayLike, *,
 
 
 @util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
-@partial(jit, static_argnames=('precision',), inline=True)
+@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def matmul(a: ArrayLike, b: ArrayLike, *,
            precision: PrecisionLike = None,
            preferred_element_type: DTypeLike | None = None,
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 7ed09de59..afabedc2d 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -741,6 +741,13 @@ def assert_dot_precision(expected_precision, fun, *args):
   else:
     assert precision == expected_precision, msg
 
+def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
+  jaxpr = api.make_jaxpr(partial(fun, **kwargs))(*args)
+  pref_eltypes = [eqn.params['preferred_element_type'] for eqn in iter_eqns(jaxpr.jaxpr)
+                  if eqn.primitive == lax.dot_general_p]
+  for pref_eltype in pref_eltypes:
+    msg = f"Unexpected preferred_element_type: {expected} != {pref_eltype}"
+    assert expected == pref_eltype, msg
 
 def cases_from_gens(*gens):
   sizes = [1, 3, 10]
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index b9f88727b..b815aa77b 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -4889,6 +4889,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
         partial(jnp.inner, precision=HIGHEST),
         ones_1d, ones_1d)
 
+  @jtu.sample_product(
+      funcname=['matmul', 'dot', 'vdot', 'tensordot']
+  )
+  def testPreferredElementType(self, funcname):
+    func = getattr(jnp, funcname)
+    kwargs = dict(axes=0) if funcname == 'tensordot' else {}
+
+    ones_i32 = np.ones(2, dtype='int32')
+    ones_f32 = np.ones(2, dtype='float32')
+
+    with jax.numpy_dtype_promotion('strict'):
+      jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs)
+      jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs)
+      jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs,
+                                            preferred_element_type='bfloat16')
+    with jax.numpy_dtype_promotion('standard'):
+      jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs)
+
   @jtu.sample_product(
     [dict(shape=shape, varargs=varargs, axis=axis)
      for shape in [(10,), (10, 15), (10, 15, 20)]
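
A minimal usage sketch of what this patch enables (illustrative only, not part of the diff; assumes the patch is applied). Registering 'preferred_element_type' in static_argnames lets the jit-wrapped jnp.dot/jnp.matmul accept a dtype keyword, which is forwarded to the underlying lax.dot_general to select the output element type:

    # Sketch: preferred_element_type as a static keyword under jit.
    import jax.numpy as jnp
    import numpy as np

    x = np.ones((2, 2), dtype='float32')

    # Default: output dtype follows the inputs (float32 here).
    print(jnp.matmul(x, x).dtype)  # float32

    # With the patch, the dtype keyword is treated as static, so each
    # distinct value triggers its own compilation rather than being
    # traced as an array argument.
    print(jnp.matmul(x, x, preferred_element_type=jnp.bfloat16).dtype)  # bfloat16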