diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index a7f832d32..2cbb113cc 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -55,6 +55,26 @@ _dtype_to_32bit_dtype = { np.dtype('complex128'): np.dtype('complex64'), } +# Note: we promote narrow types to float32 here for backward compatibility +# with earlier approaches. We might consider revisiting this, or perhaps +# tying the logic more closely to the type promotion lattice. +_dtype_to_inexact = { + np.dtype(k): np.dtype(v) for k, v in [ + ('bool', 'float32'), + ('uint8', 'float32'), ('int8', 'float32'), + ('uint16', 'float32'), ('int16', 'float32'), + ('uint32', 'float32'), ('int32', 'float32'), + ('uint64', 'float64'), ('int64', 'float64') + ] +} + + +def _to_inexact_dtype(dtype): + """Promotes a dtype into an inexact dtype, if it is not already one.""" + dtype = np.dtype(dtype) + return _dtype_to_inexact.get(dtype, dtype) + + @functools.lru_cache(maxsize=None) def _canonicalize_dtype(x64_enabled, dtype): """Convert from a dtype to a canonical dtype based on config.x64_enabled.""" diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 1a4a9afd5..7b646d933 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -270,17 +270,12 @@ def _promote_dtypes_inexact(*args): Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype) - to_dtype_inexact = _to_inexact_dtype(to_dtype) + to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype) weak_type = (weak_type and to_dtype == to_dtype_inexact) return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] -def _to_inexact_dtype(dtype): - """Promotes a dtype into an inexact dtype, if it is not already one.""" - return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_) - - def _complex_elem_type(dtype): """Returns the float type of the real/imaginary parts of a complex dtype.""" return np.abs(np.zeros((), dtype)).dtype diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3dca36331..5b26c3ada 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -789,7 +789,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 return np_op(x_cast, axis, dtype=t, keepdims=keepdims) - np_fun = _promote_like_jnp(np_fun, inexact) + jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -828,7 +828,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): res = np_op(x_cast, axis, keepdims=keepdims) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) return res - np_fun = _promote_like_jnp(np_fun, inexact) + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] tol = {np.float16: 0.002} @@ -853,13 +853,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) def np_fun(x): x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) return res - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -893,13 +893,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): where = jtu.rand_bool(self.rng())(whereshape, np.bool_) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) def np_fun(x): x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) return res - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -932,15 +932,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): @jtu.ignore_warning(category=RuntimeWarning, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, - message="invalid value encountered in true_divide*") + message="invalid value encountered.*") + @jtu.ignore_warning(category=np.ComplexWarning) def np_fun(x): x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, where=where) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) return res - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -3744,7 +3743,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = _promote_like_jnp(np_fun, inexact=True) tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22) try: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 72da856fc..f86e90e3d 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -30,6 +30,7 @@ import jax from jax import numpy as jnp from jax import lax from jax import scipy as jsp +from jax.tree_util import tree_map from jax._src import test_util as jtu from jax.scipy import special as lsp_special from jax.scipy import cluster as lsp_cluster @@ -181,8 +182,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): # TODO(mattjj): test autodiff if use_b: def scipy_fun(array_to_reduce, scale_array): - return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, - return_sign=return_sign, b=scale_array) + res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, + return_sign=return_sign, b=scale_array) + if dtype == np.int32: + res = tree_map(lambda x: x.astype('float32'), res) + return res def lax_fun(array_to_reduce, scale_array): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, @@ -191,8 +195,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)] else: def scipy_fun(array_to_reduce): - return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, - return_sign=return_sign) + res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, + return_sign=return_sign) + if dtype == np.int32: + res = tree_map(lambda x: x.astype('float32'), res) + return res def lax_fun(array_to_reduce): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,