Add missing preferred_element_type tests

Followup to https://github.com/google/jax/pull/17506
This commit is contained in:
Jake VanderPlas 2023-09-08 13:07:37 -07:00
parent 34ba4f53ff
commit 9289f3250b
3 changed files with 27 additions and 2 deletions

View File

@ -3072,7 +3072,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
@util._wraps(np.dot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def dot(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
@ -3104,7 +3104,7 @@ def dot(a: ArrayLike, b: ArrayLike, *,
@util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def matmul(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,

View File

@ -741,6 +741,13 @@ def assert_dot_precision(expected_precision, fun, *args):
else:
assert precision == expected_precision, msg
def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
jaxpr = api.make_jaxpr(partial(fun, **kwargs))(*args)
pref_eltypes = [eqn.params['preferred_element_type'] for eqn in iter_eqns(jaxpr.jaxpr)
if eqn.primitive == lax.dot_general_p]
for pref_eltype in pref_eltypes:
msg = f"Unexpected preferred_element_type: {expected} != {pref_eltype}"
assert expected == pref_eltype, msg
def cases_from_gens(*gens):
sizes = [1, 3, 10]

View File

@ -4889,6 +4889,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
partial(jnp.inner, precision=HIGHEST),
ones_1d, ones_1d)
@jtu.sample_product(
funcname=['matmul', 'dot', 'vdot', 'tensordot']
)
def testPreferredElementType(self, funcname):
func = getattr(jnp, funcname)
kwargs = dict(axes=0) if funcname == 'tensordot' else {}
ones_i32 = np.ones(2, dtype='int32')
ones_f32 = np.ones(2, dtype='float32')
with jax.numpy_dtype_promotion('strict'):
jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs)
jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs)
jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs,
preferred_element_type='bfloat16')
with jax.numpy_dtype_promotion('standard'):
jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs)
@jtu.sample_product(
[dict(shape=shape, varargs=varargs, axis=axis)
for shape in [(10,), (10, 15), (10, 15, 20)]