mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add missing preferred_element_type tests
Followup to https://github.com/google/jax/pull/17506
This commit is contained in:
parent
34ba4f53ff
commit
9289f3250b
@ -3072,7 +3072,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
|
||||
|
||||
|
||||
@util._wraps(np.dot, lax_description=_PRECISION_DOC)
|
||||
@partial(jit, static_argnames=('precision',), inline=True)
|
||||
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
|
||||
def dot(a: ArrayLike, b: ArrayLike, *,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None) -> Array:
|
||||
@ -3104,7 +3104,7 @@ def dot(a: ArrayLike, b: ArrayLike, *,
|
||||
|
||||
|
||||
@util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
|
||||
@partial(jit, static_argnames=('precision',), inline=True)
|
||||
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
|
||||
def matmul(a: ArrayLike, b: ArrayLike, *,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
|
@ -741,6 +741,13 @@ def assert_dot_precision(expected_precision, fun, *args):
|
||||
else:
|
||||
assert precision == expected_precision, msg
|
||||
|
||||
def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
|
||||
jaxpr = api.make_jaxpr(partial(fun, **kwargs))(*args)
|
||||
pref_eltypes = [eqn.params['preferred_element_type'] for eqn in iter_eqns(jaxpr.jaxpr)
|
||||
if eqn.primitive == lax.dot_general_p]
|
||||
for pref_eltype in pref_eltypes:
|
||||
msg = f"Unexpected preferred_element_type: {expected} != {pref_eltype}"
|
||||
assert expected == pref_eltype, msg
|
||||
|
||||
def cases_from_gens(*gens):
|
||||
sizes = [1, 3, 10]
|
||||
|
@ -4889,6 +4889,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
partial(jnp.inner, precision=HIGHEST),
|
||||
ones_1d, ones_1d)
|
||||
|
||||
@jtu.sample_product(
|
||||
funcname=['matmul', 'dot', 'vdot', 'tensordot']
|
||||
)
|
||||
def testPreferredElementType(self, funcname):
|
||||
func = getattr(jnp, funcname)
|
||||
kwargs = dict(axes=0) if funcname == 'tensordot' else {}
|
||||
|
||||
ones_i32 = np.ones(2, dtype='int32')
|
||||
ones_f32 = np.ones(2, dtype='float32')
|
||||
|
||||
with jax.numpy_dtype_promotion('strict'):
|
||||
jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs)
|
||||
jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs)
|
||||
jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs,
|
||||
preferred_element_type='bfloat16')
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, varargs=varargs, axis=axis)
|
||||
for shape in [(10,), (10, 15), (10, 15, 20)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user